From f5579669d81e663f0c2ef19dee6124e6d70d6491 Mon Sep 17 00:00:00 2001 From: Abraham Lazaro Martinez Date: Mon, 28 Apr 2025 17:31:29 +0000 Subject: [PATCH 1/7] merge with previous branch --- .../src/genkit/plugins/vertex_ai/__init__.py | 2 + .../plugins/vertex_ai/models/retriever.py | 312 ++++++++++++++++++ .../plugins/vertex_ai/models/vectorstore.py | 57 ++++ .../vertex_ai/vector_search/vector_search.py | 234 +++++++++++++ .../vertex-ai-vector-search-bigquery/LICENSE | 201 +++++++++++ .../README.md | 29 ++ .../pyproject.toml | 39 +++ .../src/sample.py | 102 ++++++ .../vertex-ai-vector-search-firestore/LICENSE | 201 +++++++++++ .../README.md | 28 ++ .../pyproject.toml | 39 +++ .../src/sample.py | 99 ++++++ py/uv.lock | 48 +++ 13 files changed, 1391 insertions(+) create mode 100644 py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py create mode 100644 py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/vectorstore.py create mode 100644 py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/vector_search/vector_search.py create mode 100644 py/samples/vertex-ai-vector-search-bigquery/LICENSE create mode 100644 py/samples/vertex-ai-vector-search-bigquery/README.md create mode 100644 py/samples/vertex-ai-vector-search-bigquery/pyproject.toml create mode 100644 py/samples/vertex-ai-vector-search-bigquery/src/sample.py create mode 100644 py/samples/vertex-ai-vector-search-firestore/LICENSE create mode 100644 py/samples/vertex-ai-vector-search-firestore/README.md create mode 100644 py/samples/vertex-ai-vector-search-firestore/pyproject.toml create mode 100644 py/samples/vertex-ai-vector-search-firestore/src/sample.py diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/__init__.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/__init__.py index c0ac5edf03..c635d21132 100644 --- a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/__init__.py +++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/__init__.py @@ -26,6 +26,7 @@ from 
genkit.plugins.vertex_ai.gemini import GeminiVersion from genkit.plugins.vertex_ai.imagen import ImagenOptions, ImagenVersion from genkit.plugins.vertex_ai.plugin_api import VertexAI, vertexai_name +from genkit.plugins.vertex_ai.vector_search.vector_search import VertexAIVectorSearch def package_name() -> str: @@ -46,4 +47,5 @@ def package_name() -> str: GeminiVersion.__name__, ImagenVersion.__name__, ImagenOptions.__name__, + VertexAIVectorSearch.__name__, ] diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py new file mode 100644 index 0000000000..8828749359 --- /dev/null +++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py @@ -0,0 +1,312 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +import json +from abc import ABC, abstractmethod +from typing import Any + +import structlog +from google.cloud import aiplatform_v1, bigquery, firestore +from google.cloud.aiplatform_v1 import FindNeighborsRequest, IndexDatapoint, Neighbor +from pydantic import BaseModel, Field, ValidationError + +from genkit.ai import Genkit +from genkit.blocks.document import Document +from genkit.core.typing import Embedding +from genkit.types import ActionRunContext, RetrieverRequest, RetrieverResponse + +logger = structlog.get_logger(__name__) + + +class DocRetriever(ABC): + """Abstract base class for Vertex AI Vector Search document retrieval. + + This class outlines the core workflow for retrieving relevant documents. + It is not intended to be instantiated directly. Subclasses must implement + the abstract methods to provide concrete retrieval logic depending of the + technology used. + + Attributes: + ai: The Genkit instance. + name: The name of this retriever instance. + match_service_client: The Vertex AI Matching Engine client. + embedder: The name of the embedder to use for generating embeddings. + embedder_options: Options to pass to the embedder. + """ + def __init__( + self, + ai: Genkit, + name: str, + match_service_client: aiplatform_v1.MatchServiceAsyncClient, + embedder: str, + embedder_options: dict[str, Any] | None = None, + ) -> None: + """Initializes the DocRetriever. + + Args: + ai: The Genkit application instance. + name: The name of this retriever instance. + match_service_client: The Vertex AI Matching Engine client. + embedder: The name of the embedder to use for generating embeddings. + embedder_options: Optional dictionary of options to pass to the embedder. 
+ """ + self.ai = ai + self.name = name + self._match_service_client = match_service_client + self.embedder = embedder + self.embedder_options = embedder_options or {} + + async def retrieve(self, request: RetrieverRequest, _: ActionRunContext) -> RetrieverResponse: + """Retrieves documents based on a given query. + + Args: + request: The retrieval request containing the query. + _: The ActionRunContext (unused in this method). + + Returns: + A RetrieverResponse object containing the retrieved documents. + """ + document = Document.from_document_data(document_data=request.query) + + embeddings = await self.ai.embed( + embedder=self.embedder, + documents=[document], + options=self.embedder_options, + ) + + if self.embedder_options: + top_k = self.embedder_options.get('limit') or 3 + else: + top_k = 3 + + docs = await self._get_closest_documents( + request=request, + top_k=top_k, + query_embeddings=embeddings.embeddings[0], + ) + + return RetrieverResponse(documents=[d.document for d in docs]) + + async def _get_closest_documents( + self, request: RetrieverRequest, top_k: int, query_embeddings: Embedding + ) -> list[Document]: + """Retrieves the closest documents from the vector search index based on query embeddings. + + Args: + request: The retrieval request containing the query and metadata. + top_k: The number of nearest neighbors to retrieve. + query_embeddings: The embedding of the query. + + Returns: + A list of Document objects representing the closest documents. + + Raises: + AttributeError: If the request does not contain the necessary + index endpoint path in its metadata. 
+ """ + metadata = request.query.metadata + if not metadata or 'index_endpoint_path' not in metadata: + raise AttributeError('Request provides no data about index endpoint path') + + index_endpoint_path = metadata['index_endpoint_path'] + deployed_index_id = metadata['deployed_index_id'] + + nn_request = FindNeighborsRequest( + index_endpoint=index_endpoint_path, + deployed_index_id=deployed_index_id, + queries=[ + FindNeighborsRequest.Query( + datapoint=IndexDatapoint(feature_vector=query_embeddings.embedding), + neighbor_count=top_k, + ) + ], + ) + + response = await self._match_service_client.find_neighbors(request=nn_request) + + return await self._retrieve_neighbours_data_from_db(neighbours=response.nearest_neighbors[0].neighbors) + + @abstractmethod + async def _retrieve_neighbours_data_from_db(self, neighbours: list[Neighbor]) -> list[Document]: + """Retrieves document data from the database based on neighbor information. + + This method must be implemented by subclasses to define how document + data is fetched from the database using the provided neighbor information. + + Args: + neighbours: A list of Neighbor objects representing the nearest neighbors + found in the vector search index. + + Returns: + A list of Document objects containing the data for the retrieved documents. + """ + raise NotImplementedError + + +class BigQueryRetriever(DocRetriever): + """Retrieves documents from a BigQuery table. + + This class extends DocRetriever to fetch document data from a specified BigQuery + dataset and table. It constructs a query to retrieve documents based on the IDs + obtained from nearest neighbor search results. + + Attributes: + bq_client: The BigQuery client to use for querying. + dataset_id: The ID of the BigQuery dataset. + table_id: The ID of the BigQuery table. + """ + def __init__( + self, bq_client: bigquery.Client, dataset_id: str, table_id: str, *args, **kwargs, + ) -> None: + """Initializes the BigQueryRetriever. 
+ + Args: + bq_client: The BigQuery client to use for querying. + dataset_id: The ID of the BigQuery dataset. + table_id: The ID of the BigQuery table. + *args: Additional positional arguments to pass to the parent class. + **kwargs: Additional keyword arguments to pass to the parent class. + """ + super().__init__(*args, **kwargs) + self.bq_client = bq_client + self.dataset_id = dataset_id + self.table_id = table_id + + async def _retrieve_neighbours_data_from_db(self, neighbours: list[Neighbor]) -> list[Document]: + """Retrieves document data from the BigQuery table for the given neighbors. + + Constructs and executes a BigQuery query to fetch document data based on + the IDs obtained. Handles potential errors during query execution and + document parsing. + + Args: + neighbours: A list of Neighbor objects representing the nearest neighbors. + Each neighbor should contain a datapoint with a datapoint_id. + + Returns: + A list of Document objects containing the retrieved document data. + Returns an empty list if no IDs are found in the neighbors or if the + query fails. 
+ """ + ids = [ + n.datapoint.datapoint_id + for n in neighbours + if n.datapoint and n.datapoint.datapoint_id + ] + + if not ids: + return [] + + query = f""" + SELECT * FROM `{self.dataset_id}.{self.table_id}` + WHERE id IN UNNEST(@ids) + """ + + job_config = bigquery.QueryJobConfig( + query_parameters=[bigquery.ArrayQueryParameter('ids', 'STRING', ids)], + ) + + try: + query_job = self.bq_client.query(query, job_config=job_config) + rows = query_job.result() + except Exception as e: + await logger.aerror('Failed to execute BigQuery query: %s', e) + return [] + + documents: list[Document] = [] + + for row in rows: + try: + doc_data = { + 'content': json.loads(row['content']), + } + if row.get('metadata'): + doc_data['metadata'] = json.loads(row['metadata']) + + documents.append(Document(**doc_data)) + except (ValidationError, json.JSONDecodeError, Exception) as error: + doc_id = row.get('id', '') + await logger.awarning(f'Failed to parse document data for document with ID {doc_id}: {error}') + + return documents + + +class FirestoreRetriever(DocRetriever): + """Retrieves documents from a Firestore collection. + + This class extends DocRetriever to fetch document data from a specified Firestore + collection. It retrieves documents based on IDs obtained from nearest neighbor + search results. + + Attributes: + db: The Firestore client. + collection_name: The name of the Firestore collection. + """ + def __init__( + self, firestore_client: firestore.AsyncClient, collection_name: str, *args, **kwargs, + ) -> None: + """Initializes the FirestoreRetriever. + + Args: + firestore_client: The Firestore client to use for querying. + collection_name: The name of the Firestore collection. + *args: Additional positional arguments to pass to the parent class. + **kwargs: Additional keyword arguments to pass to the parent class. 
+ """ + super().__init__(*args, **kwargs) + self.db = firestore_client + self.collection_name = collection_name + + async def _retrieve_neighbours_data_from_db(self, neighbours: list[Neighbor]) -> list[Document]: + """Retrieves document data from the Firestore collection for the given neighbors. + + Fetches document data from Firestore based on the IDs of the nearest neighbors. + Handles potential errors during document retrieval and data parsing. + + Args: + neighbours: A list of Neighbor objects representing the nearest neighbors. + Each neighbor should contain a datapoint with a datapoint_id. + + Returns: + A list of Document objects containing the retrieved document data. + Returns an empty list if no documents are found for the given IDs. + """ + documents: list[Document] = [] + + for neighbor in neighbours: + doc_ref = self.db.collection(self.collection_name).document(document_id=neighbor.datapoint.datapoint_id) + doc_snapshot = await doc_ref.get() + + if doc_snapshot.exists: + doc_data = doc_snapshot.to_dict() or {} + + try: + documents.append(Document(**doc_data)) + except ValidationError as e: + await logger.awarning( + f'Failed to parse document data for ID {neighbor.datapoint.datapoint_id}: {e}' + ) + + return documents + + +class RetrieverOptionsSchema(BaseModel): + """Schema for retriver options. + + Attributes: + limit: Number of documents to retrieve. + """ + limit: int | None = Field(title='Number of documents to retrieve', default=None) diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/vectorstore.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/vectorstore.py new file mode 100644 index 0000000000..6be441f7a3 --- /dev/null +++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/vectorstore.py @@ -0,0 +1,57 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import sys + +if sys.version_info < (3, 11): # noqa + from strenum import StrEnum # noqa +else: # noqa + from enum import StrEnum # noqa + +from pydantic import BaseModel, Field + + +class IndexShardSize(StrEnum): + """Defines the size of each shard in the index.""" + SMALL = 'SHARD_SIZE_SMALL' + MEDIUM = 'SHARD_SIZE_MEDIUM' + LARGE = 'SHARD_SIZE_LARGE' + + +class FeatureNormType(StrEnum): + """Specifies the normalization applied to feature vectors.""" + NONE = 'NONE' + UNIT_L2_NORMALIZED = 'UNIT_L2_NORM' + + +class DistanceMeasureType(StrEnum): + """Defines the available distance measure methods.""" + SQUARED_L2 = 'SQUARED_L2_DISTANCE' + L2 = 'L2_DISTANCE' + COSINE = 'COSINE_DISTANCE' + DOT_PRODUCT = 'DOT_PRODUCT_DISTANCE' + + +class IndexConfig(BaseModel): + """Defines the configurations of indexes.""" + dimensions: int = 128 + approximate_neighbors_count: int = Field(default=100, alias='approximateNeighborsCount') + distance_measure_type: DistanceMeasureType | str = Field( + default=DistanceMeasureType.COSINE, alias='distanceMeasureType' + ) + feature_norm_type: FeatureNormType | str = Field(default=FeatureNormType.NONE, alias='featureNormType') + shard_size: IndexShardSize | str = Field(default=IndexShardSize.MEDIUM, alias='shardSize') + algorithm_config: dict | None = Field(default=None, alias='algorithmConfig') diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/vector_search/vector_search.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/vector_search/vector_search.py new file mode 100644 
index 0000000000..006d37bb0f --- /dev/null +++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/vector_search/vector_search.py @@ -0,0 +1,234 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import os +from typing import Any + +import structlog +from google.auth.credentials import Credentials +from google.cloud import aiplatform_v1, storage +from google.genai.types import HttpOptions, HttpOptionsDict, Operation + +from genkit.ai import GENKIT_CLIENT_HEADER, GenkitRegistry, Plugin +from genkit.plugins.vertex_ai import vertexai_name +from genkit.plugins.vertex_ai.models.retriever import ( + DocRetriever, + RetrieverOptionsSchema, +) +from genkit.plugins.vertex_ai.models.vectorstore import IndexConfig + +logger = structlog.get_logger(__name__) + + +class VertexAIVectorSearch(Plugin): + """A plugin for integrating VertexAI Vector Search. + + This class registers VertexAI Vector Stores within a registry, + and allows interaction to retrieve similar documents. 
+ """ + + name: str = 'vertexAIVectorSearch' + + def __init__( + self, + retriever: DocRetriever, + retriever_extra_args: dict[str, Any] | None = None, + credentials: Credentials | None = None, + project: str | None = None, + location: str | None = 'us-central1', + embedder: str | None = None, + embedder_options: dict[str, Any] | None = None, + http_options: HttpOptions | HttpOptionsDict | None = None, + ) -> None: + """Initializes the VertexAIVectorSearch plugin. + + Args: + retriever: The DocRetriever class to use for retrieving documents. + retriever_extra_args: Optional dictionary of extra arguments to pass to the + retriever's constructor. + credentials: Optional Google Cloud credentials to use. If not provided, + the default application credentials will be used. + project: Optional Google Cloud project ID. If not provided, it will be + inferred from the credentials. + location: Optional Google Cloud location (region). Defaults to + 'us-central1'. + embedder: Optional identifier for the embedding model to use. + embedder_options: Optional dictionary of options to pass to the embedding + model. + http_options: Optional HTTP options for API requests. 
+ """ + http_options = _inject_attribution_headers(http_options=http_options) + + self.project = project + self.location = location + + self.embedder = embedder + self.embedder_options = embedder_options + + self.retriever_cls = retriever + self.retriever_extra_args = retriever_extra_args or {} + + self._storage_client = storage.Client( + project=self.project, + credentials=credentials, + extra_headers=http_options.headers, + ) + self._index_client = aiplatform_v1.IndexServiceAsyncClient( + credentials=credentials, + ) + self._endpoint_client = aiplatform_v1.IndexEndpointServiceAsyncClient(credentials=credentials) + self._match_service_client = aiplatform_v1.MatchServiceAsyncClient( + credentials=credentials, + ) + + async def create_index( + self, + display_name: str, + description: str | None, + index_config: IndexConfig | None = None, + contents_delta_uri: str | None = None, + ) -> None: + """Creates a Vertex AI Vector Search index. + + Args: + display_name: The display name for the index. + description: Optional description of the index. + index_config: Optional configuration for the index. If not provided, a + default configuration is used. + contents_delta_uri: Optional URI of the Cloud Storage location for the + contents delta. + """ + if not index_config: + index_config = IndexConfig() + + index = aiplatform_v1.Index() + index.display_name = display_name + index.description = description + index.metadata = { + 'config': index_config.model_dump(), + 'contentsDeltaUri': contents_delta_uri, + } + + request = aiplatform_v1.CreateIndexRequest( + parent=self.index_location_path, + index=index, + ) + + operation = await self._index_client.create_index(request=request) + + logger.debug(await operation.result()) + + async def deploy_index(self, index_name: str, endpoint_name: str) -> None: + """Deploys an index to an endpoint. + + Args: + index_name: The name of the index to deploy. + endpoint_name: The name of the endpoint to deploy the index to. 
+ """ + deployed_index = aiplatform_v1.DeployedIndex() + deployed_index.id = index_name + deployed_index.index = self.get_index_path(index_name=index_name) + + request = aiplatform_v1.DeployIndexRequest( + index_endpoint=endpoint_name, + deployed_index=deployed_index, + ) + + operation = self._endpoint_client.deploy_index(request=request) + + logger.debug(await operation.result()) + + def upload_jsonl_file(self, local_path: str, bucket_name: str, destination_location: str) -> Operation: + """Uploads a JSONL file to Cloud Storage. + + Args: + local_path: The local path to the JSONL file. + bucket_name: The name of the Cloud Storage bucket. + destination_location: The destination path within the bucket. + + Returns: + The upload operation. + """ + bucket = self._storage_client.bucket(bucket_name=bucket_name) + blob = bucket.blob(destination_location) + blob.upload_from_filename(local_path) + + def get_index_path(self, index_name: str) -> str: + """Gets the full resource path of an index. + + Args: + index_name: The name of the index. + + Returns: + The full resource path of the index. + """ + return self._index_client.index_path(project=self.project, location=self.location, index=index_name) + + @property + def index_location_path(self) -> str: + """Gets the resource path of the index location. + + Returns: + The resource path of the index location. + """ + return self._index_client.common_location_path(project=self.project, location=self.location) + + def initialize(self, ai: GenkitRegistry) -> None: + """Initialize plugin with the retriver specified. + + Register actions with the registry making them available for use in the Genkit framework. + + Args: + ai: The registry to register actions with. 
+ """ + retriever = self.retriever_cls( + ai=ai, + name=self.name, + match_service_client=self._match_service_client, + embedder=self.embedder, + embedder_options=self.embedder_options, + **self.retriever_extra_args, + ) + + return ai.define_retriever( + name=vertexai_name(self.name), + config_schema=RetrieverOptionsSchema, + fn=retriever.retrieve, + ) + + +def _inject_attribution_headers(http_options) -> HttpOptions: + """Adds genkit client info to the appropriate http headers.""" + if not http_options: + http_options = HttpOptions() + else: + if isinstance(http_options, dict): + http_options = HttpOptions(**http_options) + + if not http_options.headers: + http_options.headers = {} + + if 'x-goog-api-client' not in http_options.headers: + http_options.headers['x-goog-api-client'] = GENKIT_CLIENT_HEADER + else: + http_options.headers['x-goog-api-client'] += f' {GENKIT_CLIENT_HEADER}' + + if 'user-agent' not in http_options.headers: + http_options.headers['user-agent'] = GENKIT_CLIENT_HEADER + else: + http_options.headers['user-agent'] += f' {GENKIT_CLIENT_HEADER}' + + return http_options diff --git a/py/samples/vertex-ai-vector-search-bigquery/LICENSE b/py/samples/vertex-ai-vector-search-bigquery/LICENSE new file mode 100644 index 0000000000..2205396735 --- /dev/null +++ b/py/samples/vertex-ai-vector-search-bigquery/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/py/samples/vertex-ai-vector-search-bigquery/README.md b/py/samples/vertex-ai-vector-search-bigquery/README.md new file mode 100644 index 0000000000..47d03ab707 --- /dev/null +++ b/py/samples/vertex-ai-vector-search-bigquery/README.md @@ -0,0 +1,29 @@ +# Vertex AI - Vector Search BigQuery + +An example demonstrating the use Vector Search API with BigQuery retriever for Vertex AI + +## Setup environment + +1. Install [GCP CLI](https://cloud.google.com/sdk/docs/install). +2. Run the following code to connect to VertexAI. +```bash +gcloud auth application-default login` +``` +3. Set the following env vars to run the sample +``` +export LOCATION='' +export PROJECT_ID='' +export BIGQUERY_DATASET='' +export BIGQUERY_TABLE='' +export VECTOR_SEARCH_DEPLOYED_INDEX_ID='' +export VECTOR_SEARCH_INDEX_ENDPOINT_ID='' +export VECTOR_SEARCH_INDEX_ID='' +export VECTOR_SEARCH_PUBLIC_DOMAIN_NAME='' +``` +4. Run the sample. 
+ +## Run the sample + +```bash +genkit start -- uv run src/sample.py +``` diff --git a/py/samples/vertex-ai-vector-search-bigquery/pyproject.toml b/py/samples/vertex-ai-vector-search-bigquery/pyproject.toml new file mode 100644 index 0000000000..7eae7e480c --- /dev/null +++ b/py/samples/vertex-ai-vector-search-bigquery/pyproject.toml @@ -0,0 +1,39 @@ +[project] +authors = [{ name = "Google" }] +classifiers = [ + "Development Status :: 3 - Alpha", + "Environment :: Console", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries", +] +dependencies = [ + "genkit", + "genkit-plugin-vertex-ai", + "pydantic>=2.10.5", + "structlog>=25.2.0", + "google-cloud-bigquery", + "strenum>=0.4.15; python_version < '3.11'", +] +description = "An example demonstrating the use Vector Search API with BigQuery retriever for Vertex AI" +license = { text = "Apache-2.0" } +name = "vertex-ai-vector-search-bigquery" +readme = "README.md" +requires-python = ">=3.10" +version = "0.1.0" + +[build-system] +build-backend = "hatchling.build" +requires = ["hatchling"] + +[tool.hatch.build.targets.wheel] +packages = ["src/sample"] diff --git a/py/samples/vertex-ai-vector-search-bigquery/src/sample.py b/py/samples/vertex-ai-vector-search-bigquery/src/sample.py new file mode 100644 index 0000000000..71b7713496 --- /dev/null +++ b/py/samples/vertex-ai-vector-search-bigquery/src/sample.py @@ -0,0 +1,102 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this 
file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import os +import time + +from google.cloud import aiplatform, bigquery +from pydantic import BaseModel + +from genkit.ai import Genkit +from genkit.blocks.document import Document +from genkit.plugins.vertex_ai import ( + EmbeddingModels, + VertexAI, + VertexAIVectorSearch, + vertexai_name, +) +from genkit.plugins.vertex_ai.models.retriever import BigQueryRetriever + +LOCATION = os.getenv('LOCATION') +PROJECT_ID = os.getenv('PROJECT_ID') +BIGQUERY_DATASET = os.getenv('BIGQUERY_DATASET') +BIGQUERY_TABLE = os.getenv('BIGQUERY_TABLE') +VECTOR_SEARCH_DEPLOYED_INDEX_ID = os.getenv('VECTOR_SEARCH_DEPLOYED_INDEX_ID') +VECTOR_SEARCH_INDEX_ENDPOINT_ID = os.getenv('VECTOR_SEARCH_INDEX_ENDPOINT_ID') +VECTOR_SEARCH_INDEX_ID = os.getenv('VECTOR_SEARCH_INDEX_ID') +VECTOR_SEARCH_PUBLIC_DOMAIN_NAME = os.getenv('VECTOR_SEARCH_PUBLIC_DOMAIN_NAME') + + +bq_client = bigquery.Client(project=PROJECT_ID) +aiplatform.init(project=PROJECT_ID, location=LOCATION) + + +ai = Genkit( + plugins=[ + VertexAI(), + VertexAIVectorSearch( + retriever=BigQueryRetriever, + retriever_extra_args={ + 'bq_client': bq_client, + 'dataset_id': BIGQUERY_DATASET, + 'table_id': BIGQUERY_TABLE, + }, + embedder=EmbeddingModels.TEXT_EMBEDDING_004_ENG, + embedder_options={'taskType': 'RETRIEVAL_DOCUMENT'}, + ), + ] +) + + +class QueryFlowInputSchema(BaseModel): + query: str + k: int + + +class QueryFlowOutputSchema(BaseModel): + result: list[dict] + length: int + time: int + + +@ai.flow(name='queryFlow') +async def 
query_flow(_input: QueryFlowInputSchema) -> QueryFlowOutputSchema: + start_time = time.time() + query_document = Document.from_text(text=_input.query) + + result: list[Document] = await ai.retrieve( + retriever=vertexai_name(VECTOR_SEARCH_INDEX_ID), + query=query_document, + ) + + end_time = time.time() + + duration = int(end_time - start_time) + + result_data = [] + for doc in result: + result_data.append({ + 'text': doc.content[0].root.text, + 'distance': doc.metadata.get('distance'), + }) + + result_data = sorted(result_data, key=lambda x: x['distance']) + + return QueryFlowOutputSchema( + result=result_data, + length=len(result_data), + time=duration, + ) diff --git a/py/samples/vertex-ai-vector-search-firestore/LICENSE b/py/samples/vertex-ai-vector-search-firestore/LICENSE new file mode 100644 index 0000000000..2205396735 --- /dev/null +++ b/py/samples/vertex-ai-vector-search-firestore/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. 
+ + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/py/samples/vertex-ai-vector-search-firestore/README.md b/py/samples/vertex-ai-vector-search-firestore/README.md new file mode 100644 index 0000000000..69299cc371 --- /dev/null +++ b/py/samples/vertex-ai-vector-search-firestore/README.md @@ -0,0 +1,28 @@ +# Vertex AI Vector Search Firestore + +An example demonstrating the use Vector Search API with Firestore retriever for Vertex AI + +## Setup environment + +1. Install [GCP CLI](https://cloud.google.com/sdk/docs/install). +2. Run the following code to connect to VertexAI. +```bash +gcloud auth application-default login` +``` +3. Set the following env vars to run the sample +``` +export LOCATION='' +export PROJECT_ID='' +export FIRESTORE_COLLECTION='' +export VECTOR_SEARCH_DEPLOYED_INDEX_ID='' +export VECTOR_SEARCH_INDEX_ENDPOINT_ID='' +export VECTOR_SEARCH_INDEX_ID='' +export VECTOR_SEARCH_PUBLIC_DOMAIN_NAME='' +``` +4. Run the sample. 
+ +## Run the sample + +```bash +genkit start -- uv run src/sample.py +``` diff --git a/py/samples/vertex-ai-vector-search-firestore/pyproject.toml b/py/samples/vertex-ai-vector-search-firestore/pyproject.toml new file mode 100644 index 0000000000..3413399903 --- /dev/null +++ b/py/samples/vertex-ai-vector-search-firestore/pyproject.toml @@ -0,0 +1,39 @@ +[project] +authors = [{ name = "Google" }] +classifiers = [ + "Development Status :: 3 - Alpha", + "Environment :: Console", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries", +] +dependencies = [ + "genkit", + "genkit-plugin-vertex-ai", + "pydantic>=2.10.5", + "structlog>=25.2.0", + "google-cloud-firestore", + "strenum>=0.4.15; python_version < '3.11'", +] +description = "An example demonstrating the use Vector Search API with Firestore retriever for Vertex AI" +license = { text = "Apache-2.0" } +name = "vertex-ai-vector-search-firestore" +readme = "README.md" +requires-python = ">=3.10" +version = "0.1.0" + +[build-system] +build-backend = "hatchling.build" +requires = ["hatchling"] + +[tool.hatch.build.targets.wheel] +packages = ["src/sample"] diff --git a/py/samples/vertex-ai-vector-search-firestore/src/sample.py b/py/samples/vertex-ai-vector-search-firestore/src/sample.py new file mode 100644 index 0000000000..b34fdba291 --- /dev/null +++ b/py/samples/vertex-ai-vector-search-firestore/src/sample.py @@ -0,0 +1,99 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not 
use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import os +import time + +from google.cloud import aiplatform, firestore +from pydantic import BaseModel + +from genkit.ai import Genkit +from genkit.blocks.document import Document +from genkit.plugins.vertex_ai import ( + EmbeddingModels, + VertexAI, + VertexAIVectorSearch, + vertexai_name, +) +from genkit.plugins.vertex_ai.models.retriever import FirestoreRetriever + +LOCATION = os.getenv('LOCATION') +PROJECT_ID = os.getenv('PROJECT_ID') +FIRESTORE_COLLECTION = os.getenv('FIRESTORE_COLLECTION') +VECTOR_SEARCH_DEPLOYED_INDEX_ID = os.getenv('VECTOR_SEARCH_DEPLOYED_INDEX_ID') +VECTOR_SEARCH_INDEX_ENDPOINT_ID = os.getenv('VECTOR_SEARCH_INDEX_ENDPOINT_ID') +VECTOR_SEARCH_INDEX_ID = os.getenv('VECTOR_SEARCH_INDEX_ID') +VECTOR_SEARCH_PUBLIC_DOMAIN_NAME = os.getenv('VECTOR_SEARCH_PUBLIC_DOMAIN_NAME') + +firestore_client = firestore.Client(project=PROJECT_ID) +aiplatform.init(project=PROJECT_ID, location=LOCATION) + + +ai = Genkit( + plugins=[ + VertexAI(), + VertexAIVectorSearch( + retriever=FirestoreRetriever, + retriever_extra_args={ + 'firestore_client': firestore_client, + 'collection_name': FIRESTORE_COLLECTION, + }, + embedder=EmbeddingModels.TEXT_EMBEDDING_004_ENG, + embedder_options={'taskType': 'RETRIEVAL_DOCUMENT'}, + ), + ] +) + + +class QueryFlowInputSchema(BaseModel): + query: str + k: int + + +class QueryFlowOutputSchema(BaseModel): + result: list[dict] + length: int + time: int + + +@ai.flow(name='queryFlow') +async def query_flow(_input: 
QueryFlowInputSchema) -> QueryFlowOutputSchema: + start_time = time.time() + query_document = Document.from_text(text=_input.query) + + result: list[Document] = await ai.retrieve( + retriever=vertexai_name(VECTOR_SEARCH_INDEX_ID), + query=query_document, + ) + + end_time = time.time() + + duration = int(end_time - start_time) + + result_data = [] + for doc in result: + result_data.append({ + 'text': doc.content[0].root.text, + 'distance': doc.metadata.get('distance'), + }) + + result_data = sorted(result_data, key=lambda x: x['distance']) + + return QueryFlowOutputSchema( + result=result_data, + length=len(result_data), + time=duration, + ) diff --git a/py/uv.lock b/py/uv.lock index 0585aac076..07b6084c0d 100644 --- a/py/uv.lock +++ b/py/uv.lock @@ -36,6 +36,8 @@ members = [ "ollama-simple-embed", "short-n-long", "tool-interrupts", + "vertex-ai-vector-search-bigquery", + "vertex-ai-vector-search-firestore", ] [[package]] @@ -4617,6 +4619,52 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/63/9a/0962b05b308494e3202d3f794a6e85abe471fe3cafdbcf95c2e8c713aabd/uvloop-0.21.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a5c39f217ab3c663dc699c04cbd50c13813e31d917642d459fdcec07555cc553", size = 4660018 }, ] +[[package]] +name = "vertex-ai-vector-search-bigquery" +version = "0.1.0" +source = { editable = "samples/vertex-ai-vector-search-bigquery" } +dependencies = [ + { name = "genkit" }, + { name = "genkit-plugin-vertex-ai" }, + { name = "google-cloud-bigquery" }, + { name = "pydantic" }, + { name = "strenum", marker = "python_full_version < '3.11'" }, + { name = "structlog" }, +] + +[package.metadata] +requires-dist = [ + { name = "genkit", editable = "packages/genkit" }, + { name = "genkit-plugin-vertex-ai", editable = "plugins/vertex-ai" }, + { name = "google-cloud-bigquery" }, + { name = "pydantic", specifier = ">=2.10.5" }, + { name = "strenum", marker = "python_full_version < '3.11'", specifier = ">=0.4.15" }, + { name = "structlog", specifier = 
">=25.2.0" }, +] + +[[package]] +name = "vertex-ai-vector-search-firestore" +version = "0.1.0" +source = { editable = "samples/vertex-ai-vector-search-firestore" } +dependencies = [ + { name = "genkit" }, + { name = "genkit-plugin-vertex-ai" }, + { name = "google-cloud-firestore" }, + { name = "pydantic" }, + { name = "strenum", marker = "python_full_version < '3.11'" }, + { name = "structlog" }, +] + +[package.metadata] +requires-dist = [ + { name = "genkit", editable = "packages/genkit" }, + { name = "genkit-plugin-vertex-ai", editable = "plugins/vertex-ai" }, + { name = "google-cloud-firestore" }, + { name = "pydantic", specifier = ">=2.10.5" }, + { name = "strenum", marker = "python_full_version < '3.11'", specifier = ">=0.4.15" }, + { name = "structlog", specifier = ">=25.2.0" }, +] + [[package]] name = "virtualenv" version = "20.30.0" From 7588fe34a1bb8811db9d611eefea3efbae1e50fd Mon Sep 17 00:00:00 2001 From: Abraham Lazaro Martinez Date: Fri, 2 May 2025 19:12:53 +0000 Subject: [PATCH 2/7] first version of samples --- py/plugins/vertex-ai/pyproject.toml | 3 + .../plugins/vertex_ai/models/retriever.py | 17 +- .../vertex_ai/vector_search/vector_search.py | 10 +- .../src/sample.py | 53 ++- .../src/setup_env.py | 307 ++++++++++++++++++ py/uv.lock | 8 +- 6 files changed, 380 insertions(+), 18 deletions(-) create mode 100644 py/samples/vertex-ai-vector-search-bigquery/src/setup_env.py diff --git a/py/plugins/vertex-ai/pyproject.toml b/py/plugins/vertex-ai/pyproject.toml index 425bee532e..ff2dd3c283 100644 --- a/py/plugins/vertex-ai/pyproject.toml +++ b/py/plugins/vertex-ai/pyproject.toml @@ -18,10 +18,13 @@ classifiers = [ ] dependencies = [ "genkit", + "google-genai>=1.7.0", "google-cloud-aiplatform>=1.77.0", "pytest-mock", "structlog>=25.2.0", "strenum>=0.4.15; python_version < '3.11'", + "google-cloud-bigquery", + "google-cloud-firestore", ] description = "Genkit Google Cloud Vertex AI Plugin" license = { text = "Apache-2.0" } diff --git 
a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py index 8828749359..242f1d07b3 100644 --- a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py +++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py @@ -19,6 +19,7 @@ from typing import Any import structlog +from google.auth.credentials import Credentials from google.cloud import aiplatform_v1, bigquery, firestore from google.cloud.aiplatform_v1 import FindNeighborsRequest, IndexDatapoint, Neighbor from pydantic import BaseModel, Field, ValidationError @@ -50,9 +51,10 @@ def __init__( self, ai: Genkit, name: str, - match_service_client: aiplatform_v1.MatchServiceAsyncClient, + match_service_client: aiplatform_v1.MatchServiceAsyncClient | None, embedder: str, embedder_options: dict[str, Any] | None = None, + credentials: Credentials | None = None, ) -> None: """Initializes the DocRetriever. @@ -61,6 +63,7 @@ def __init__( name: The name of this retriever instance. match_service_client: The Vertex AI Matching Engine client. embedder: The name of the embedder to use for generating embeddings. + Already added plugin prefix. embedder_options: Optional dictionary of options to pass to the embedder. """ self.ai = ai @@ -68,6 +71,7 @@ def __init__( self._match_service_client = match_service_client self.embedder = embedder self.embedder_options = embedder_options or {} + self.credentials = credentials async def retrieve(self, request: RetrieverRequest, _: ActionRunContext) -> RetrieverResponse: """Retrieves documents based on a given query. 
@@ -92,6 +96,7 @@ async def retrieve(self, request: RetrieverRequest, _: ActionRunContext) -> Retr else: top_k = 3 + logger.debug(f'top k neighbors: {top_k}') docs = await self._get_closest_documents( request=request, top_k=top_k, @@ -135,8 +140,16 @@ async def _get_closest_documents( ], ) - response = await self._match_service_client.find_neighbors(request=nn_request) + logger.debug('Before find neighbors') + match_service_client = self._match_service_client + if match_service_client is None: + match_service_client = aiplatform_v1.MatchServiceAsyncClient( + credentials=self.credentials, + ) + + response = await match_service_client.find_neighbors(request=nn_request) + logger.debug('After find neighbors') return await self._retrieve_neighbours_data_from_db(neighbours=response.nearest_neighbors[0].neighbors) @abstractmethod diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/vector_search/vector_search.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/vector_search/vector_search.py index 006d37bb0f..5a97cd0dfc 100644 --- a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/vector_search/vector_search.py +++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/vector_search/vector_search.py @@ -90,9 +90,8 @@ def __init__( credentials=credentials, ) self._endpoint_client = aiplatform_v1.IndexEndpointServiceAsyncClient(credentials=credentials) - self._match_service_client = aiplatform_v1.MatchServiceAsyncClient( - credentials=credentials, - ) + self._match_service_client = None + self.credentials=credentials async def create_index( self, @@ -120,6 +119,7 @@ async def create_index( index.metadata = { 'config': index_config.model_dump(), 'contentsDeltaUri': contents_delta_uri, + 'index_update_method': 'STREAM_UPDATE' # TODO: Add the other 2 } request = aiplatform_v1.CreateIndexRequest( @@ -194,12 +194,14 @@ def initialize(self, ai: GenkitRegistry) -> None: Args: ai: The registry to register actions with. 
""" + logger.debug(f'Register retriever with {self.name} name') retriever = self.retriever_cls( ai=ai, name=self.name, match_service_client=self._match_service_client, - embedder=self.embedder, + embedder=vertexai_name(self.embedder), embedder_options=self.embedder_options, + credentials=self.credentials, **self.retriever_extra_args, ) diff --git a/py/samples/vertex-ai-vector-search-bigquery/src/sample.py b/py/samples/vertex-ai-vector-search-bigquery/src/sample.py index 71b7713496..e9bc8a6e32 100644 --- a/py/samples/vertex-ai-vector-search-bigquery/src/sample.py +++ b/py/samples/vertex-ai-vector-search-bigquery/src/sample.py @@ -17,11 +17,14 @@ import os import time +import structlog from google.cloud import aiplatform, bigquery from pydantic import BaseModel from genkit.ai import Genkit -from genkit.blocks.document import Document +from genkit.blocks.document import ( + Document, +) from genkit.plugins.vertex_ai import ( EmbeddingModels, VertexAI, @@ -32,17 +35,20 @@ LOCATION = os.getenv('LOCATION') PROJECT_ID = os.getenv('PROJECT_ID') -BIGQUERY_DATASET = os.getenv('BIGQUERY_DATASET') -BIGQUERY_TABLE = os.getenv('BIGQUERY_TABLE') +EMBEDDING_MODEL = EmbeddingModels.TEXT_EMBEDDING_004_ENG + +BIGQUERY_DATASET_NAME = os.getenv('BIGQUERY_DATASET_NAME') +BIGQUERY_TABLE_NAME = os.getenv('BIGQUERY_TABLE_NAME') + +VECTOR_SEARCH_INDEX_ID = os.getenv('VECTOR_SEARCH_INDEX_ID') + VECTOR_SEARCH_DEPLOYED_INDEX_ID = os.getenv('VECTOR_SEARCH_DEPLOYED_INDEX_ID') VECTOR_SEARCH_INDEX_ENDPOINT_ID = os.getenv('VECTOR_SEARCH_INDEX_ENDPOINT_ID') -VECTOR_SEARCH_INDEX_ID = os.getenv('VECTOR_SEARCH_INDEX_ID') VECTOR_SEARCH_PUBLIC_DOMAIN_NAME = os.getenv('VECTOR_SEARCH_PUBLIC_DOMAIN_NAME') - bq_client = bigquery.Client(project=PROJECT_ID) aiplatform.init(project=PROJECT_ID, location=LOCATION) - +logger = structlog.get_logger(__name__) ai = Genkit( plugins=[ @@ -51,22 +57,27 @@ retriever=BigQueryRetriever, retriever_extra_args={ 'bq_client': bq_client, - 'dataset_id': BIGQUERY_DATASET, - 
'table_id': BIGQUERY_TABLE, + 'dataset_id': BIGQUERY_DATASET_NAME, + 'table_id': BIGQUERY_TABLE_NAME, + }, + embedder=EMBEDDING_MODEL, + embedder_options={ + 'task': 'RETRIEVAL_DOCUMENT', + 'output_dimensionality': 128, }, - embedder=EmbeddingModels.TEXT_EMBEDDING_004_ENG, - embedder_options={'taskType': 'RETRIEVAL_DOCUMENT'}, ), ] ) class QueryFlowInputSchema(BaseModel): + """Input schema.""" query: str k: int class QueryFlowOutputSchema(BaseModel): + """Output schema.""" result: list[dict] length: int time: int @@ -74,12 +85,18 @@ class QueryFlowOutputSchema(BaseModel): @ai.flow(name='queryFlow') async def query_flow(_input: QueryFlowInputSchema) -> QueryFlowOutputSchema: + """Executes a vector search with VertexAI Vector Search.""" start_time = time.time() query_document = Document.from_text(text=_input.query) + query_document.metadata = { + 'index_endpoint_path': 'projects/206382651113/locations/us-central1/indexEndpoints/8485065371965980672', + 'deployed_index_id': 'genkit_sample_1746207451742', + } result: list[Document] = await ai.retrieve( - retriever=vertexai_name(VECTOR_SEARCH_INDEX_ID), + retriever=vertexai_name('vertexAIVectorSearch'), query=query_document, + ) end_time = time.time() @@ -100,3 +117,17 @@ async def query_flow(_input: QueryFlowInputSchema) -> QueryFlowOutputSchema: length=len(result_data), time=duration, ) + + +async def main() -> None: + """Main function.""" + query_input = QueryFlowInputSchema( + query="Pedro", + k=10, + ) + + await logger.ainfo(await query_flow(query_input)) + + +if __name__ == '__main__': + ai.run_main(main()) diff --git a/py/samples/vertex-ai-vector-search-bigquery/src/setup_env.py b/py/samples/vertex-ai-vector-search-bigquery/src/setup_env.py new file mode 100644 index 0000000000..ca7eea03b5 --- /dev/null +++ b/py/samples/vertex-ai-vector-search-bigquery/src/setup_env.py @@ -0,0 +1,307 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file 
except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import json +import os + +import structlog +from google.cloud import aiplatform, aiplatform_v1, bigquery + +from genkit import types +from genkit.ai import Genkit +from genkit.plugins.vertex_ai import ( + EmbeddingModels, + VertexAI, + VertexAIVectorSearch, + vertexai_name, +) +from genkit.plugins.vertex_ai.models.retriever import BigQueryRetriever +from genkit.plugins.vertex_ai.models.vectorstore import ( + DistanceMeasureType, + FeatureNormType, + IndexConfig, + IndexShardSize, +) + +# Environment Variables +LOCATION = os.getenv('LOCATION') +PROJECT_ID = os.getenv('PROJECT_ID') +EMBEDDING_MODEL = EmbeddingModels.TEXT_EMBEDDING_004_ENG + + +BIGQUERY_DATASET_NAME = os.getenv('BIGQUERY_DATASET_NAME') +BIGQUERY_TABLE_NAME = os.getenv('BIGQUERY_TABLE_NAME') + +VECTOR_SEARCH_INDEX_ID = os.getenv('VECTOR_SEARCH_INDEX_ID') + +VECTOR_SEARCH_DEPLOYED_INDEX_ID = os.getenv('VECTOR_SEARCH_DEPLOYED_INDEX_ID') +VECTOR_SEARCH_INDEX_ENDPOINT_ID = os.getenv('VECTOR_SEARCH_INDEX_ENDPOINT_ID') +VECTOR_SEARCH_PUBLIC_DOMAIN_NAME = os.getenv('VECTOR_SEARCH_PUBLIC_DOMAIN_NAME') + +# Initialize Clients +logger = structlog.get_logger(__name__) +bq_client = bigquery.Client(project=PROJECT_ID) +aiplatform.init(project=PROJECT_ID, location=LOCATION) + +# Configure Genkit +ai = Genkit( + plugins=[ + VertexAI(), + VertexAIVectorSearch( + retriever=BigQueryRetriever, + retriever_extra_args={ + 'bq_client': bq_client, + 'dataset_id': BIGQUERY_DATASET_NAME, + 'table_id': BIGQUERY_TABLE_NAME, + }, + 
embedder=EMBEDDING_MODEL, + embedder_options={'task': 'RETRIEVAL_DOCUMENT'}, + ) + ] +) + + +@ai.flow(name='generateEmbeddings') +async def generate_embeddings(): + """Generates document embeddings and upserts them to the Vertex AI Vector Search index. + + This flow retrieves data from BigQuery, generates embeddings for the documents, + and then upserts these embeddings to the specified Vector Search index. + """ + toy_documents = [ + { + "id": "doc1", + "content": {"title": "Document 1", "body": "This is the content of document 1."}, + "metadata": {"author": "Alice", "date": "2024-01-15"}, + }, + { + "id": "doc2", + "content": {"title": "Document 2", "body": "This is the content of document 2."}, + "metadata": {"author": "Bob", "date": "2024-02-20"}, + }, + { + "id": "doc3", + "content": {"title": "Document 3", "body": "Content for doc 3"}, + "metadata": {"author": "Charlie", "date": "2024-03-01"}, + }, + ] + + create_bigquery_dataset_and_table( + PROJECT_ID, LOCATION, BIGQUERY_DATASET_NAME, BIGQUERY_TABLE_NAME, toy_documents, + ) + + results_dict = get_data_from_bigquery( + bq_client=bq_client, + project_id=PROJECT_ID, + dataset_id=BIGQUERY_DATASET_NAME, + table_id=BIGQUERY_TABLE_NAME, + ) + + genkit_documents = [ + types.Document(content=[types.TextPart(text=text)]) + for text in results_dict.values() + ] + + embed_response = await ai.embed( + embedder=vertexai_name(EMBEDDING_MODEL), + documents=genkit_documents, + options={'task': 'RETRIEVAL_DOCUMENT', 'output_dimensionality': 128}, + ) + + embeddings = [emb.embedding for emb in embed_response.embeddings] + logger.debug(f'Generated {len(embeddings)} embeddings, dimension: {len(embeddings[0])}') + + ids = list(results_dict.keys())[:len(embeddings)] + data_embeddings = list(zip(ids, embeddings, strict=True)) + + upsert_data = [(str(id), embedding) for id, embedding in data_embeddings] + upsert_index(PROJECT_ID, LOCATION, VECTOR_SEARCH_INDEX_ID, upsert_data) + + +def create_bigquery_dataset_and_table( + 
project_id: str, + location: str, + dataset_id: str, + table_id: str, + documents: list[dict[str, str]], +) -> None: + """Creates a BigQuery dataset and table, and inserts documents. + + Args: + project_id: The ID of the Google Cloud project. + location: The location for the BigQuery resources. + dataset_id: The ID of the BigQuery dataset. + table_id: The ID of the BigQuery table. + documents: A list of dictionaries, where each dictionary represents a document + with 'id', 'content', and 'metadata' keys. 'content' and 'metadata' + are expected to be JSON serializable. + """ + client = bigquery.Client(project=project_id) + dataset_ref = bigquery.DatasetReference(project_id, dataset_id) + dataset = bigquery.Dataset(dataset_ref) + dataset.location = location + + try: + dataset = client.create_dataset(dataset, exists_ok=True) + logger.debug(f"Dataset {client.project}.{dataset.dataset_id} created.") + except Exception as e: + logger.exception(f"Error creating dataset: {e}") + raise e + + schema = [ + bigquery.SchemaField("id", "STRING", mode="REQUIRED"), + bigquery.SchemaField("content", "JSON"), + bigquery.SchemaField("metadata", "JSON"), + ] + + table_ref = dataset_ref.table(table_id) + table = bigquery.Table(table_ref, schema=schema) + try: + table = client.create_table(table, exists_ok=True) + logger.debug(f"Table {table.project}.{table.dataset_id}.{table.table_id} created.") + except Exception as e: + logger.exception(f"Error creating table: {e}") + raise e + + rows_to_insert = [ + { + "id": doc["id"], + "content": json.dumps(doc["content"]), + "metadata": json.dumps(doc["metadata"]), + } + for doc in documents + ] + + errors = client.insert_rows_json(table, rows_to_insert) + if errors: + logger.error(f"Errors inserting rows: {errors}") + raise Exception(f"Failed to insert rows: {errors}") + else: + logger.debug(f"Inserted {len(rows_to_insert)} rows into BigQuery.") + + +def get_data_from_bigquery( + bq_client: bigquery.Client, + project_id: str, + dataset_id: str, 
+ table_id: str, +) -> dict[str, str]: + """Retrieves data from a BigQuery table. + + Args: + bq_client: The BigQuery client. + project_id: The ID of the Google Cloud project. + dataset_id: The ID of the BigQuery dataset. + table_id: The ID of the BigQuery table. + + Returns: + A dictionary where keys are document IDs and values are JSON strings + representing the document content. + """ + table_ref = bigquery.TableReference.from_string( + f"{project_id}.{dataset_id}.{table_id}" + ) + query = f"SELECT id, content FROM `{table_ref}`" + query_job = bq_client.query(query) + rows = query_job.result() + + results = {row['id']: json.dumps(row['content']) for row in rows} + logger.debug(f'Found {len(results)} rows with different ids into BigQuery.') + + return results + + +async def create_and_deploy_index( + vector_search_index_id: str, + vector_search_deployed_index_id: str, + vector_search_index_endpoint_id: str, +) -> None: + """Creates and deploys a Vertex AI Vector Search Index. + + Args: + vector_search_index_id: The ID of the Vector Search index to create. + vector_search_deployed_index_id: The ID of the deployed index. + vector_search_index_endpoint_id: The ID of the Vector Search index endpoint. 
+ """ + vertex_ai_vector_search = VertexAIVectorSearch() + + logger.debug('Creating VertexAI Vector Search Index') + await vertex_ai_vector_search.create_index( + display_name=vector_search_index_id, + description='Toy index for genkit sample', + index_config=IndexConfig( + dimensions=128, + approximate_neighbors_count=100, + distance_measure_type=DistanceMeasureType.COSINE, + feature_norm_type=FeatureNormType.UNIT_L2_NORMALIZED, + shard_size=IndexShardSize.MEDIUM, + ), + ) + + logger.debug('Deploying VertexAI Vector Search Index') + await vertex_ai_vector_search.deploy_index( + index_name=vector_search_deployed_index_id, + endpoint_name=vector_search_index_endpoint_id, + ) + + +def upsert_index( + project_id: str, + region: str, + index_name: str, + data: list[tuple[str, list[float]]], +) -> None: + """Upserts data points to a Vertex AI Index using batch processing. + + Args: + project_id: The ID of your Google Cloud project. + region: The region where the Index is located. + index_name: The name of the Vertex AI Index. + data: A list of tuples, where each tuple contains (id, embedding). + id should be a string, and embedding should be a list of floats. + """ + aiplatform.init(project=project_id, location=region) + + index_client = aiplatform_v1.IndexServiceClient( + client_options={"api_endpoint": f"{region}-aiplatform.googleapis.com"} + ) + + index_path = index_client.index_path( + project=project_id, location=region, index=index_name + ) + + datapoints = [ + aiplatform_v1.IndexDatapoint(datapoint_id=id, feature_vector=embedding) + for id, embedding in data + ] + + logger.debug(f'Attempting to insert {len(datapoints)} rows into Index {index_path}') + + upsert_request = aiplatform_v1.UpsertDatapointsRequest( + index=index_path, datapoints=datapoints + ) + + response = index_client.upsert_datapoints(request=upsert_request) + logger.info(f"Upserted {len(datapoints)} datapoints. 
Response: {response}") + + +async def main() -> None: + """Main function.""" + await logger.ainfo(await generate_embeddings()) + + +if __name__ == '__main__': + ai.run_main(main()) diff --git a/py/uv.lock b/py/uv.lock index 07b6084c0d..776dfe058b 100644 --- a/py/uv.lock +++ b/py/uv.lock @@ -963,7 +963,7 @@ wheels = [ [[package]] name = "genkit" -version = "0.3.2" +version = "0.3.1" source = { editable = "packages/genkit" } dependencies = [ { name = "anyio" }, @@ -1160,6 +1160,9 @@ source = { editable = "plugins/vertex-ai" } dependencies = [ { name = "genkit" }, { name = "google-cloud-aiplatform" }, + { name = "google-cloud-bigquery" }, + { name = "google-cloud-firestore" }, + { name = "google-genai" }, { name = "pytest-mock" }, { name = "strenum", marker = "python_full_version < '3.11'" }, { name = "structlog" }, @@ -1169,6 +1172,9 @@ dependencies = [ requires-dist = [ { name = "genkit", editable = "packages/genkit" }, { name = "google-cloud-aiplatform", specifier = ">=1.77.0" }, + { name = "google-cloud-bigquery" }, + { name = "google-cloud-firestore" }, + { name = "google-genai", specifier = ">=1.7.0" }, { name = "pytest-mock" }, { name = "strenum", marker = "python_full_version < '3.11'", specifier = ">=0.4.15" }, { name = "structlog", specifier = ">=25.2.0" }, From 2b223b4aff0165fc53405fa18c6c0cfb6c704611 Mon Sep 17 00:00:00 2001 From: Abraham Lazaro Martinez Date: Tue, 6 May 2025 14:05:31 +0000 Subject: [PATCH 3/7] fix: vertex ai bigquery vector search --- .../plugins/vertex_ai/models/retriever.py | 76 +++++++++++-------- .../vertex_ai/vector_search/vector_search.py | 15 ++-- .../README.md | 15 ++-- .../src/sample.py | 21 ++--- 4 files changed, 72 insertions(+), 55 deletions(-) diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py index 242f1d07b3..50a4628317 100644 --- a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py +++ 
b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py @@ -16,11 +16,11 @@ import json from abc import ABC, abstractmethod +from collections.abc import Callable from typing import Any import structlog -from google.auth.credentials import Credentials -from google.cloud import aiplatform_v1, bigquery, firestore +from google.cloud import bigquery, firestore from google.cloud.aiplatform_v1 import FindNeighborsRequest, IndexDatapoint, Neighbor from pydantic import BaseModel, Field, ValidationError @@ -51,27 +51,28 @@ def __init__( self, ai: Genkit, name: str, - match_service_client: aiplatform_v1.MatchServiceAsyncClient | None, + match_service_client_generator: Callable, embedder: str, embedder_options: dict[str, Any] | None = None, - credentials: Credentials | None = None, + limit: int | None = None, ) -> None: """Initializes the DocRetriever. Args: ai: The Genkit application instance. name: The name of this retriever instance. - match_service_client: The Vertex AI Matching Engine client. + match_service_client_generator: The Vertex AI Matching Engine client. embedder: The name of the embedder to use for generating embeddings. Already added plugin prefix. embedder_options: Optional dictionary of options to pass to the embedder. + limit: Optional limit of neighbors to find. """ self.ai = ai self.name = name - self._match_service_client = match_service_client self.embedder = embedder self.embedder_options = embedder_options or {} - self.credentials = credentials + self._match_service_client_generator = match_service_client_generator + self.limit = limit or 3 async def retrieve(self, request: RetrieverRequest, _: ActionRunContext) -> RetrieverResponse: """Retrieves documents based on a given query. 
@@ -85,25 +86,25 @@ async def retrieve(self, request: RetrieverRequest, _: ActionRunContext) -> Retr """ document = Document.from_document_data(document_data=request.query) + # Removing limit key from embedder options + # TODO: Think a better patter of usage + custom_embedder_options = self.embedder_options.copy() + if 'limit' in custom_embedder_options.keys(): + del custom_embedder_options['limit'] + embeddings = await self.ai.embed( embedder=self.embedder, documents=[document], - options=self.embedder_options, + options=custom_embedder_options, ) - if self.embedder_options: - top_k = self.embedder_options.get('limit') or 3 - else: - top_k = 3 - - logger.debug(f'top k neighbors: {top_k}') docs = await self._get_closest_documents( request=request, - top_k=top_k, + top_k=self.limit, query_embeddings=embeddings.embeddings[0], ) - return RetrieverResponse(documents=[d.document for d in docs]) + return RetrieverResponse(documents=docs) async def _get_closest_documents( self, request: RetrieverRequest, top_k: int, query_embeddings: Embedding @@ -123,12 +124,21 @@ async def _get_closest_documents( index endpoint path in its metadata. 
""" metadata = request.query.metadata - if not metadata or 'index_endpoint_path' not in metadata: + if not metadata or 'index_endpoint_path' not in metadata or 'api_endpoint' not in metadata: raise AttributeError('Request provides no data about index endpoint path') + api_endpoint = metadata['api_endpoint'] index_endpoint_path = metadata['index_endpoint_path'] deployed_index_id = metadata['deployed_index_id'] + client_options = { + "api_endpoint": api_endpoint + } + + vector_search_client = self._match_service_client_generator( + client_options=client_options, + ) + nn_request = FindNeighborsRequest( index_endpoint=index_endpoint_path, deployed_index_id=deployed_index_id, @@ -140,16 +150,8 @@ async def _get_closest_documents( ], ) - logger.debug('Before find neighbors') - - match_service_client = self._match_service_client - if match_service_client is None: - match_service_client = aiplatform_v1.MatchServiceAsyncClient( - credentials=self.credentials, - ) + response = await vector_search_client.find_neighbors(request=nn_request) - response = await match_service_client.find_neighbors(request=nn_request) - logger.debug('After find neighbors') return await self._retrieve_neighbours_data_from_db(neighbours=response.nearest_neighbors[0].neighbors) @abstractmethod @@ -220,6 +222,12 @@ async def _retrieve_neighbours_data_from_db(self, neighbours: list[Neighbor]) -> if n.datapoint and n.datapoint.datapoint_id ] + distance_by_id = { + n.datapoint.datapoint_id: n.distance + for n in neighbours + if n.datapoint and n.datapoint.datapoint_id + } + if not ids: return [] @@ -243,13 +251,17 @@ async def _retrieve_neighbours_data_from_db(self, neighbours: list[Neighbor]) -> for row in rows: try: - doc_data = { - 'content': json.loads(row['content']), - } - if row.get('metadata'): - doc_data['metadata'] = json.loads(row['metadata']) + id = row['id'] + + content = row['content'] + content = content if isinstance(content, str) else json.dumps(row['content']) + + metadata = 
row.get('metadata', {}) + metadata = metadata if isinstance(metadata, dict) else json.loads(metadata) + metadata['id'] = id + metadata['distance'] = distance_by_id[id] - documents.append(Document(**doc_data)) + documents.append(Document.from_text(content, metadata)) except (ValidationError, json.JSONDecodeError, Exception) as error: doc_id = row.get('id', '') await logger.awarning(f'Failed to parse document data for document with ID {doc_id}: {error}') diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/vector_search/vector_search.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/vector_search/vector_search.py index 5a97cd0dfc..bbf35a0668 100644 --- a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/vector_search/vector_search.py +++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/vector_search/vector_search.py @@ -14,7 +14,7 @@ # # SPDX-License-Identifier: Apache-2.0 -import os +from functools import partial from typing import Any import structlog @@ -90,8 +90,11 @@ def __init__( credentials=credentials, ) self._endpoint_client = aiplatform_v1.IndexEndpointServiceAsyncClient(credentials=credentials) - self._match_service_client = None - self.credentials=credentials + + self._match_service_client_generator = partial( + aiplatform_v1.MatchServiceAsyncClient, + credentials=credentials, + ) async def create_index( self, @@ -194,14 +197,12 @@ def initialize(self, ai: GenkitRegistry) -> None: Args: ai: The registry to register actions with. 
""" - logger.debug(f'Register retriever with {self.name} name') retriever = self.retriever_cls( ai=ai, name=self.name, - match_service_client=self._match_service_client, - embedder=vertexai_name(self.embedder), + match_service_client_generator=self._match_service_client_generator, + embedder=self.embedder, embedder_options=self.embedder_options, - credentials=self.credentials, **self.retriever_extra_args, ) diff --git a/py/samples/vertex-ai-vector-search-bigquery/README.md b/py/samples/vertex-ai-vector-search-bigquery/README.md index 47d03ab707..43f9dec108 100644 --- a/py/samples/vertex-ai-vector-search-bigquery/README.md +++ b/py/samples/vertex-ai-vector-search-bigquery/README.md @@ -7,18 +7,17 @@ An example demonstrating the use Vector Search API with BigQuery retriever for V 1. Install [GCP CLI](https://cloud.google.com/sdk/docs/install). 2. Run the following code to connect to VertexAI. ```bash -gcloud auth application-default login` +gcloud auth application-default login ``` 3. Set the following env vars to run the sample ``` export LOCATION='' export PROJECT_ID='' -export BIGQUERY_DATASET='' -export BIGQUERY_TABLE='' +export BIGQUERY_DATASET_NAME='' +export BIGQUERY_TABLE_NAME='' export VECTOR_SEARCH_DEPLOYED_INDEX_ID='' -export VECTOR_SEARCH_INDEX_ENDPOINT_ID='' -export VECTOR_SEARCH_INDEX_ID='' -export VECTOR_SEARCH_PUBLIC_DOMAIN_NAME='' +export VECTOR_SEARCH_INDEX_ENDPOINT_PATH='' +export VECTOR_SEARCH_API_ENDPOINT='' ``` 4. Run the sample. @@ -27,3 +26,7 @@ export VECTOR_SEARCH_PUBLIC_DOMAIN_NAME='' ```bash genkit start -- uv run src/sample.py ``` + +## Set up env for sample +In the file `setup_env.py` you will find some code that will help you to create the bigquery dataset, table with the expected schema, encode the content of the table and push this to the VertexAI Vector Search index. +This index must be created with update method set as `stream`. You will also find some code to create and deploy the index in this file. 
diff --git a/py/samples/vertex-ai-vector-search-bigquery/src/sample.py b/py/samples/vertex-ai-vector-search-bigquery/src/sample.py index e9bc8a6e32..7bbafe2a1c 100644 --- a/py/samples/vertex-ai-vector-search-bigquery/src/sample.py +++ b/py/samples/vertex-ai-vector-search-bigquery/src/sample.py @@ -40,14 +40,13 @@ BIGQUERY_DATASET_NAME = os.getenv('BIGQUERY_DATASET_NAME') BIGQUERY_TABLE_NAME = os.getenv('BIGQUERY_TABLE_NAME') -VECTOR_SEARCH_INDEX_ID = os.getenv('VECTOR_SEARCH_INDEX_ID') - VECTOR_SEARCH_DEPLOYED_INDEX_ID = os.getenv('VECTOR_SEARCH_DEPLOYED_INDEX_ID') -VECTOR_SEARCH_INDEX_ENDPOINT_ID = os.getenv('VECTOR_SEARCH_INDEX_ENDPOINT_ID') -VECTOR_SEARCH_PUBLIC_DOMAIN_NAME = os.getenv('VECTOR_SEARCH_PUBLIC_DOMAIN_NAME') +VECTOR_SEARCH_INDEX_ENDPOINT_PATH = os.getenv('VECTOR_SEARCH_INDEX_ENDPOINT_PATH') +VECTOR_SEARCH_API_ENDPOINT = os.getenv('VECTOR_SEARCH_API_ENDPOINT') bq_client = bigquery.Client(project=PROJECT_ID) aiplatform.init(project=PROJECT_ID, location=LOCATION) + logger = structlog.get_logger(__name__) ai = Genkit( @@ -60,10 +59,11 @@ 'dataset_id': BIGQUERY_DATASET_NAME, 'table_id': BIGQUERY_TABLE_NAME, }, - embedder=EMBEDDING_MODEL, + embedder=vertexai_name(EMBEDDING_MODEL), embedder_options={ 'task': 'RETRIEVAL_DOCUMENT', 'output_dimensionality': 128, + 'limit': 10, }, ), ] @@ -89,14 +89,14 @@ async def query_flow(_input: QueryFlowInputSchema) -> QueryFlowOutputSchema: start_time = time.time() query_document = Document.from_text(text=_input.query) query_document.metadata = { - 'index_endpoint_path': 'projects/206382651113/locations/us-central1/indexEndpoints/8485065371965980672', - 'deployed_index_id': 'genkit_sample_1746207451742', + 'api_endpoint': VECTOR_SEARCH_API_ENDPOINT, + 'index_endpoint_path': VECTOR_SEARCH_INDEX_ENDPOINT_PATH, + 'deployed_index_id': VECTOR_SEARCH_DEPLOYED_INDEX_ID, } result: list[Document] = await ai.retrieve( retriever=vertexai_name('vertexAIVectorSearch'), query=query_document, - )
end_time = time.time() @@ -104,8 +104,9 @@ async def query_flow(_input: QueryFlowInputSchema) -> QueryFlowOutputSchema: duration = int(end_time - start_time) result_data = [] - for doc in result: + for doc in result.documents: result_data.append({ + 'id': doc.metadata.get('id'), 'text': doc.content[0].root.text, 'distance': doc.metadata.get('distance'), }) @@ -122,7 +123,7 @@ async def query_flow(_input: QueryFlowInputSchema) -> QueryFlowOutputSchema: async def main() -> None: """Main function.""" query_input = QueryFlowInputSchema( - query="Pedro", + query="Content for doc", k=10, ) From 70e722a8117294135c68b5dfce6dfa35805b5813 Mon Sep 17 00:00:00 2001 From: Abraham Lazaro Martinez Date: Tue, 6 May 2025 14:15:43 +0000 Subject: [PATCH 4/7] fix: add limit to options --- .../plugins/vertex_ai/models/retriever.py | 19 ++++++++----------- .../src/sample.py | 6 +++++- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py index 50a4628317..cbb4c172ec 100644 --- a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py +++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py @@ -31,6 +31,8 @@ logger = structlog.get_logger(__name__) +DEFAULT_LIMIT_NEIGHBORS: int = 3 + class DocRetriever(ABC): """Abstract base class for Vertex AI Vector Search document retrieval. @@ -54,7 +56,6 @@ def __init__( match_service_client_generator: Callable, embedder: str, embedder_options: dict[str, Any] | None = None, - limit: int | None = None, ) -> None: """Initializes the DocRetriever. @@ -65,14 +66,12 @@ def __init__( embedder: The name of the embedder to use for generating embeddings. Already added plugin prefix. embedder_options: Optional dictionary of options to pass to the embedder. - limit: Optional limit of neighbors to find. 
""" self.ai = ai self.name = name self.embedder = embedder self.embedder_options = embedder_options or {} self._match_service_client_generator = match_service_client_generator - self.limit = limit or 3 async def retrieve(self, request: RetrieverRequest, _: ActionRunContext) -> RetrieverResponse: """Retrieves documents based on a given query. @@ -86,21 +85,19 @@ async def retrieve(self, request: RetrieverRequest, _: ActionRunContext) -> Retr """ document = Document.from_document_data(document_data=request.query) - # Removing limit key from embedder options - # TODO: Think a better patter of usage - custom_embedder_options = self.embedder_options.copy() - if 'limit' in custom_embedder_options.keys(): - del custom_embedder_options['limit'] - embeddings = await self.ai.embed( embedder=self.embedder, documents=[document], - options=custom_embedder_options, + options=self.embedder_options, ) + limit_neighbors = DEFAULT_LIMIT_NEIGHBORS + if isinstance(request.options, dict) and request.options.get('limit') is not None: + limit_neighbors = request.options.get('limit') + docs = await self._get_closest_documents( request=request, - top_k=self.limit, + top_k=limit_neighbors, query_embeddings=embeddings.embeddings[0], ) diff --git a/py/samples/vertex-ai-vector-search-bigquery/src/sample.py b/py/samples/vertex-ai-vector-search-bigquery/src/sample.py index 7bbafe2a1c..b7cea3f7e6 100644 --- a/py/samples/vertex-ai-vector-search-bigquery/src/sample.py +++ b/py/samples/vertex-ai-vector-search-bigquery/src/sample.py @@ -63,7 +63,6 @@ embedder_options={ 'task': 'RETRIEVAL_DOCUMENT', 'output_dimensionality': 128, - 'limit': 10, }, ), ] @@ -94,9 +93,14 @@ async def query_flow(_input: QueryFlowInputSchema) -> QueryFlowOutputSchema: 'deployed_index_id': VECTOR_SEARCH_DEPLOYED_INDEX_ID, } + options = { + 'limit': 10, + } + result: list[Document] = await ai.retrieve( retriever=vertexai_name('vertexAIVectorSearch'), query=query_document, + options=options, ) end_time = time.time() From 
ba1a1a11f3c37e9c46990be4f6fbd1e381a64615 Mon Sep 17 00:00:00 2001 From: Abraham Lazaro Martinez Date: Tue, 6 May 2025 15:35:35 +0000 Subject: [PATCH 5/7] drop unused code --- .../plugins/vertex_ai/models/vectorstore.py | 57 -------- .../vertex_ai/vector_search/vector_search.py | 137 +----------------- .../README.md | 2 +- .../src/sample.py | 2 +- .../src/setup_env.py | 50 +------ 5 files changed, 6 insertions(+), 242 deletions(-) delete mode 100644 py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/vectorstore.py diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/vectorstore.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/vectorstore.py deleted file mode 100644 index 6be441f7a3..0000000000 --- a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/vectorstore.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -# SPDX-License-Identifier: Apache-2.0 - -import sys - -if sys.version_info < (3, 11): # noqa - from strenum import StrEnum # noqa -else: # noqa - from enum import StrEnum # noqa - -from pydantic import BaseModel, Field - - -class IndexShardSize(StrEnum): - """Defines the size of each shard in the index.""" - SMALL = 'SHARD_SIZE_SMALL' - MEDIUM = 'SHARD_SIZE_MEDIUM' - LARGE = 'SHARD_SIZE_LARGE' - - -class FeatureNormType(StrEnum): - """Specifies the normalization applied to feature vectors.""" - NONE = 'NONE' - UNIT_L2_NORMALIZED = 'UNIT_L2_NORM' - - -class DistanceMeasureType(StrEnum): - """Defines the available distance measure methods.""" - SQUARED_L2 = 'SQUARED_L2_DISTANCE' - L2 = 'L2_DISTANCE' - COSINE = 'COSINE_DISTANCE' - DOT_PRODUCT = 'DOT_PRODUCT_DISTANCE' - - -class IndexConfig(BaseModel): - """Defines the configurations of indexes.""" - dimensions: int = 128 - approximate_neighbors_count: int = Field(default=100, alias='approximateNeighborsCount') - distance_measure_type: DistanceMeasureType | str = Field( - default=DistanceMeasureType.COSINE, alias='distanceMeasureType' - ) - feature_norm_type: FeatureNormType | str = Field(default=FeatureNormType.NONE, alias='featureNormType') - shard_size: IndexShardSize | str = Field(default=IndexShardSize.MEDIUM, alias='shardSize') - algorithm_config: dict | None = Field(default=None, alias='algorithmConfig') diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/vector_search/vector_search.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/vector_search/vector_search.py index bbf35a0668..7126ca65b4 100644 --- a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/vector_search/vector_search.py +++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/vector_search/vector_search.py @@ -19,16 +19,14 @@ import structlog from google.auth.credentials import Credentials -from google.cloud import aiplatform_v1, storage -from google.genai.types import HttpOptions, HttpOptionsDict, Operation +from google.cloud 
import aiplatform_v1 -from genkit.ai import GENKIT_CLIENT_HEADER, GenkitRegistry, Plugin +from genkit.ai import GenkitRegistry, Plugin from genkit.plugins.vertex_ai import vertexai_name from genkit.plugins.vertex_ai.models.retriever import ( DocRetriever, RetrieverOptionsSchema, ) -from genkit.plugins.vertex_ai.models.vectorstore import IndexConfig logger = structlog.get_logger(__name__) @@ -51,7 +49,6 @@ def __init__( location: str | None = 'us-central1', embedder: str | None = None, embedder_options: dict[str, Any] | None = None, - http_options: HttpOptions | HttpOptionsDict | None = None, ) -> None: """Initializes the VertexAIVectorSearch plugin. @@ -68,10 +65,7 @@ def __init__( embedder: Optional identifier for the embedding model to use. embedder_options: Optional dictionary of options to pass to the embedding model. - http_options: Optional HTTP options for API requests. """ - http_options = _inject_attribution_headers(http_options=http_options) - self.project = project self.location = location @@ -81,114 +75,11 @@ def __init__( self.retriever_cls = retriever self.retriever_extra_args = retriever_extra_args or {} - self._storage_client = storage.Client( - project=self.project, - credentials=credentials, - extra_headers=http_options.headers, - ) - self._index_client = aiplatform_v1.IndexServiceAsyncClient( - credentials=credentials, - ) - self._endpoint_client = aiplatform_v1.IndexEndpointServiceAsyncClient(credentials=credentials) - self._match_service_client_generator = partial( aiplatform_v1.MatchServiceAsyncClient, credentials=credentials, ) - async def create_index( - self, - display_name: str, - description: str | None, - index_config: IndexConfig | None = None, - contents_delta_uri: str | None = None, - ) -> None: - """Creates a Vertex AI Vector Search index. - - Args: - display_name: The display name for the index. - description: Optional description of the index. - index_config: Optional configuration for the index. 
If not provided, a - default configuration is used. - contents_delta_uri: Optional URI of the Cloud Storage location for the - contents delta. - """ - if not index_config: - index_config = IndexConfig() - - index = aiplatform_v1.Index() - index.display_name = display_name - index.description = description - index.metadata = { - 'config': index_config.model_dump(), - 'contentsDeltaUri': contents_delta_uri, - 'index_update_method': 'STREAM_UPDATE' # TODO: Add the other 2 - } - - request = aiplatform_v1.CreateIndexRequest( - parent=self.index_location_path, - index=index, - ) - - operation = await self._index_client.create_index(request=request) - - logger.debug(await operation.result()) - - async def deploy_index(self, index_name: str, endpoint_name: str) -> None: - """Deploys an index to an endpoint. - - Args: - index_name: The name of the index to deploy. - endpoint_name: The name of the endpoint to deploy the index to. - """ - deployed_index = aiplatform_v1.DeployedIndex() - deployed_index.id = index_name - deployed_index.index = self.get_index_path(index_name=index_name) - - request = aiplatform_v1.DeployIndexRequest( - index_endpoint=endpoint_name, - deployed_index=deployed_index, - ) - - operation = self._endpoint_client.deploy_index(request=request) - - logger.debug(await operation.result()) - - def upload_jsonl_file(self, local_path: str, bucket_name: str, destination_location: str) -> Operation: - """Uploads a JSONL file to Cloud Storage. - - Args: - local_path: The local path to the JSONL file. - bucket_name: The name of the Cloud Storage bucket. - destination_location: The destination path within the bucket. - - Returns: - The upload operation. - """ - bucket = self._storage_client.bucket(bucket_name=bucket_name) - blob = bucket.blob(destination_location) - blob.upload_from_filename(local_path) - - def get_index_path(self, index_name: str) -> str: - """Gets the full resource path of an index. - - Args: - index_name: The name of the index. 
- - Returns: - The full resource path of the index. - """ - return self._index_client.index_path(project=self.project, location=self.location, index=index_name) - - @property - def index_location_path(self) -> str: - """Gets the resource path of the index location. - - Returns: - The resource path of the index location. - """ - return self._index_client.common_location_path(project=self.project, location=self.location) - def initialize(self, ai: GenkitRegistry) -> None: """Initialize plugin with the retriver specified. @@ -211,27 +102,3 @@ def initialize(self, ai: GenkitRegistry) -> None: config_schema=RetrieverOptionsSchema, fn=retriever.retrieve, ) - - -def _inject_attribution_headers(http_options) -> HttpOptions: - """Adds genkit client info to the appropriate http headers.""" - if not http_options: - http_options = HttpOptions() - else: - if isinstance(http_options, dict): - http_options = HttpOptions(**http_options) - - if not http_options.headers: - http_options.headers = {} - - if 'x-goog-api-client' not in http_options.headers: - http_options.headers['x-goog-api-client'] = GENKIT_CLIENT_HEADER - else: - http_options.headers['x-goog-api-client'] += f' {GENKIT_CLIENT_HEADER}' - - if 'user-agent' not in http_options.headers: - http_options.headers['user-agent'] = GENKIT_CLIENT_HEADER - else: - http_options.headers['user-agent'] += f' {GENKIT_CLIENT_HEADER}' - - return http_options diff --git a/py/samples/vertex-ai-vector-search-bigquery/README.md b/py/samples/vertex-ai-vector-search-bigquery/README.md index 43f9dec108..c7a43ca694 100644 --- a/py/samples/vertex-ai-vector-search-bigquery/README.md +++ b/py/samples/vertex-ai-vector-search-bigquery/README.md @@ -29,4 +29,4 @@ genkit start -- uv run src/sample.py ## Set up env for sample In the file `setup_env.py` you will find some code that will help you to create the bigquery dataset, table with the expected schema, encode the content of the table and push this to the VertexAI Vector Search index. 
-This index must be created with update method set as `stream`. You will also find some code to create and deploy the index in this file. +This index must be created with update method set as `stream`. VertexAI Index is expected to be already created. diff --git a/py/samples/vertex-ai-vector-search-bigquery/src/sample.py b/py/samples/vertex-ai-vector-search-bigquery/src/sample.py index b7cea3f7e6..6f8038b069 100644 --- a/py/samples/vertex-ai-vector-search-bigquery/src/sample.py +++ b/py/samples/vertex-ai-vector-search-bigquery/src/sample.py @@ -128,7 +128,7 @@ async def main() -> None: """Main function.""" query_input = QueryFlowInputSchema( query="Content for doc", - k=10, + k=3, ) await logger.ainfo(await query_flow(query_input)) diff --git a/py/samples/vertex-ai-vector-search-bigquery/src/setup_env.py b/py/samples/vertex-ai-vector-search-bigquery/src/setup_env.py index ca7eea03b5..7fc6710c9b 100644 --- a/py/samples/vertex-ai-vector-search-bigquery/src/setup_env.py +++ b/py/samples/vertex-ai-vector-search-bigquery/src/setup_env.py @@ -29,34 +29,22 @@ vertexai_name, ) from genkit.plugins.vertex_ai.models.retriever import BigQueryRetriever -from genkit.plugins.vertex_ai.models.vectorstore import ( - DistanceMeasureType, - FeatureNormType, - IndexConfig, - IndexShardSize, -) # Environment Variables LOCATION = os.getenv('LOCATION') PROJECT_ID = os.getenv('PROJECT_ID') EMBEDDING_MODEL = EmbeddingModels.TEXT_EMBEDDING_004_ENG - BIGQUERY_DATASET_NAME = os.getenv('BIGQUERY_DATASET_NAME') BIGQUERY_TABLE_NAME = os.getenv('BIGQUERY_TABLE_NAME') VECTOR_SEARCH_INDEX_ID = os.getenv('VECTOR_SEARCH_INDEX_ID') -VECTOR_SEARCH_DEPLOYED_INDEX_ID = os.getenv('VECTOR_SEARCH_DEPLOYED_INDEX_ID') -VECTOR_SEARCH_INDEX_ENDPOINT_ID = os.getenv('VECTOR_SEARCH_INDEX_ENDPOINT_ID') -VECTOR_SEARCH_PUBLIC_DOMAIN_NAME = os.getenv('VECTOR_SEARCH_PUBLIC_DOMAIN_NAME') - -# Initialize Clients -logger = structlog.get_logger(__name__) bq_client = bigquery.Client(project=PROJECT_ID) 
aiplatform.init(project=PROJECT_ID, location=LOCATION) -# Configure Genkit +logger = structlog.get_logger(__name__) + ai = Genkit( plugins=[ VertexAI(), @@ -224,40 +212,6 @@ def get_data_from_bigquery( return results -async def create_and_deploy_index( - vector_search_index_id: str, - vector_search_deployed_index_id: str, - vector_search_index_endpoint_id: str, -) -> None: - """Creates and deploys a Vertex AI Vector Search Index. - - Args: - vector_search_index_id: The ID of the Vector Search index to create. - vector_search_deployed_index_id: The ID of the deployed index. - vector_search_index_endpoint_id: The ID of the Vector Search index endpoint. - """ - vertex_ai_vector_search = VertexAIVectorSearch() - - logger.debug('Creating VertexAI Vector Search Index') - await vertex_ai_vector_search.create_index( - display_name=vector_search_index_id, - description='Toy index for genkit sample', - index_config=IndexConfig( - dimensions=128, - approximate_neighbors_count=100, - distance_measure_type=DistanceMeasureType.COSINE, - feature_norm_type=FeatureNormType.UNIT_L2_NORMALIZED, - shard_size=IndexShardSize.MEDIUM, - ), - ) - - logger.debug('Deploying VertexAI Vector Search Index') - await vertex_ai_vector_search.deploy_index( - index_name=vector_search_deployed_index_id, - endpoint_name=vector_search_index_endpoint_id, - ) - - def upsert_index( project_id: str, region: str, From 9c9bae15eec340e28e2392eff07a88a17f81923f Mon Sep 17 00:00:00 2001 From: Abraham Lazaro Martinez Date: Tue, 6 May 2025 16:09:41 +0000 Subject: [PATCH 6/7] fix: firestore sample --- .../plugins/vertex_ai/models/retriever.py | 16 +++++- .../src/sample.py | 1 + .../src/sample.py | 50 ++++++++++++++++--- 3 files changed, 58 insertions(+), 9 deletions(-) diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py index cbb4c172ec..62684ee6e3 100644 --- 
a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py +++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py @@ -310,13 +310,25 @@ async def _retrieve_neighbours_data_from_db(self, neighbours: list[Neighbor]) -> for neighbor in neighbours: doc_ref = self.db.collection(self.collection_name).document(document_id=neighbor.datapoint.datapoint_id) - doc_snapshot = await doc_ref.get() + doc_snapshot = doc_ref.get() if doc_snapshot.exists: doc_data = doc_snapshot.to_dict() or {} + content = doc_data.get('content') + content = json.dumps(content) if isinstance(content, dict) else str(content) + + metadata = doc_data.get('metadata', {}) + metadata['id'] = neighbor.datapoint.datapoint_id + metadata['distance'] = neighbor.distance + try: - documents.append(Document(**doc_data)) + documents.append( + Document.from_text( + content, + metadata, + ) + ) except ValidationError as e: await logger.awarning( f'Failed to parse document data for ID {neighbor.datapoint.datapoint_id}: {e}' diff --git a/py/samples/vertex-ai-vector-search-bigquery/src/sample.py b/py/samples/vertex-ai-vector-search-bigquery/src/sample.py index 6f8038b069..39c994a261 100644 --- a/py/samples/vertex-ai-vector-search-bigquery/src/sample.py +++ b/py/samples/vertex-ai-vector-search-bigquery/src/sample.py @@ -86,6 +86,7 @@ class QueryFlowOutputSchema(BaseModel): async def query_flow(_input: QueryFlowInputSchema) -> QueryFlowOutputSchema: """Executes a vector search with VertexAI Vector Search.""" start_time = time.time() + query_document = Document.from_text(text=_input.query) query_document.metadata = { 'api_endpoint': VECTOR_SEARCH_API_ENDPOINT, diff --git a/py/samples/vertex-ai-vector-search-firestore/src/sample.py b/py/samples/vertex-ai-vector-search-firestore/src/sample.py index b34fdba291..a8bd67f563 100644 --- a/py/samples/vertex-ai-vector-search-firestore/src/sample.py +++ b/py/samples/vertex-ai-vector-search-firestore/src/sample.py @@ -17,6 +17,7 @@ import os import 
time +import structlog +from google.cloud import aiplatform, firestore from pydantic import BaseModel @@ -32,15 +33,18 @@ LOCATION = os.getenv('LOCATION') PROJECT_ID = os.getenv('PROJECT_ID') +EMBEDDING_MODEL = EmbeddingModels.TEXT_EMBEDDING_004_ENG + FIRESTORE_COLLECTION = os.getenv('FIRESTORE_COLLECTION') + VECTOR_SEARCH_DEPLOYED_INDEX_ID = os.getenv('VECTOR_SEARCH_DEPLOYED_INDEX_ID') -VECTOR_SEARCH_INDEX_ENDPOINT_ID = os.getenv('VECTOR_SEARCH_INDEX_ENDPOINT_ID') -VECTOR_SEARCH_INDEX_ID = os.getenv('VECTOR_SEARCH_INDEX_ID') -VECTOR_SEARCH_PUBLIC_DOMAIN_NAME = os.getenv('VECTOR_SEARCH_PUBLIC_DOMAIN_NAME') +VECTOR_SEARCH_INDEX_ENDPOINT_PATH = os.getenv('VECTOR_SEARCH_INDEX_ENDPOINT_PATH') +VECTOR_SEARCH_API_ENDPOINT = os.getenv('VECTOR_SEARCH_API_ENDPOINT') firestore_client = firestore.Client(project=PROJECT_ID) aiplatform.init(project=PROJECT_ID, location=LOCATION) +logger = structlog.get_logger(__name__) ai = Genkit( plugins=[ @@ -51,19 +55,24 @@ 'firestore_client': firestore_client, 'collection_name': FIRESTORE_COLLECTION, }, - embedder=EmbeddingModels.TEXT_EMBEDDING_004_ENG, - embedder_options={'taskType': 'RETRIEVAL_DOCUMENT'}, + embedder=vertexai_name(EMBEDDING_MODEL), + embedder_options={ + 'task': 'RETRIEVAL_DOCUMENT', + 'output_dimensionality': 128, + }, ), ] ) class QueryFlowInputSchema(BaseModel): + """Input schema.""" query: str k: int class QueryFlowOutputSchema(BaseModel): + """Output schema.""" result: list[dict] length: int time: int @@ -71,12 +80,24 @@ class QueryFlowOutputSchema(BaseModel): @ai.flow(name='queryFlow') async def query_flow(_input: QueryFlowInputSchema) -> QueryFlowOutputSchema: + """Executes a vector search with VertexAI Vector Search.""" start_time = time.time() + query_document = Document.from_text(text=_input.query) + query_document.metadata = { + 'api_endpoint': VECTOR_SEARCH_API_ENDPOINT, + 'index_endpoint_path': VECTOR_SEARCH_INDEX_ENDPOINT_PATH, + 'deployed_index_id': 
VECTOR_SEARCH_DEPLOYED_INDEX_ID, + } + + options = { + 'limit': 10, + } result: list[Document] = await ai.retrieve( - retriever=vertexai_name(VECTOR_SEARCH_INDEX_ID), + retriever=vertexai_name('vertexAIVectorSearch'), query=query_document, + options=options, ) end_time = time.time() @@ -84,8 +105,9 @@ async def query_flow(_input: QueryFlowInputSchema) -> QueryFlowOutputSchema: duration = int(end_time - start_time) result_data = [] - for doc in result: + for doc in result.documents: result_data.append({ + 'id': doc.metadata.get('id'), 'text': doc.content[0].root.text, 'distance': doc.metadata.get('distance'), }) @@ -97,3 +119,17 @@ async def query_flow(_input: QueryFlowInputSchema) -> QueryFlowOutputSchema: length=len(result_data), time=duration, ) + + +async def main() -> None: + """Main function.""" + query_input = QueryFlowInputSchema( + query="Content for doc", + k=3, + ) + + await logger.ainfo(await query_flow(query_input)) + + +if __name__ == '__main__': + ai.run_main(main()) From 62367c957e3a857fecc9eba81126533069ded87d Mon Sep 17 00:00:00 2001 From: Abraham Lazaro Martinez Date: Tue, 6 May 2025 16:13:24 +0000 Subject: [PATCH 7/7] fix: readme --- .../src/genkit/plugins/vertex_ai/models/retriever.py | 3 +-- py/samples/vertex-ai-vector-search-firestore/README.md | 9 ++++----- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py index 62684ee6e3..23e6644d95 100644 --- a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py +++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/retriever.py @@ -251,10 +251,9 @@ async def _retrieve_neighbours_data_from_db(self, neighbours: list[Neighbor]) -> id = row['id'] content = row['content'] - content = content if isinstance(content, str) else json.dumps(row['content']) + content = json.dumps(content) if isinstance(content, dict) else str(content) 
metadata = row.get('metadata', {}) - metadata = metadata if isinstance(metadata, dict) else json.loads(metadata) metadata['id'] = id metadata['distance'] = distance_by_id[id] diff --git a/py/samples/vertex-ai-vector-search-firestore/README.md b/py/samples/vertex-ai-vector-search-firestore/README.md index 69299cc371..591c705d9a 100644 --- a/py/samples/vertex-ai-vector-search-firestore/README.md +++ b/py/samples/vertex-ai-vector-search-firestore/README.md @@ -1,4 +1,4 @@ -# Vertex AI Vector Search Firestore +# Vertex AI - Vector Search Firestore An example demonstrating the use Vector Search API with Firestore retriever for Vertex AI @@ -7,7 +7,7 @@ An example demonstrating the use Vector Search API with Firestore retriever for 1. Install [GCP CLI](https://cloud.google.com/sdk/docs/install). 2. Run the following code to connect to VertexAI. ```bash -gcloud auth application-default login` +gcloud auth application-default login ``` 3. Set the following env vars to run the sample ``` @@ -15,9 +15,8 @@ export LOCATION='' export PROJECT_ID='' export FIRESTORE_COLLECTION='' export VECTOR_SEARCH_DEPLOYED_INDEX_ID='' -export VECTOR_SEARCH_INDEX_ENDPOINT_ID='' -export VECTOR_SEARCH_INDEX_ID='' -export VECTOR_SEARCH_PUBLIC_DOMAIN_NAME='' +export VECTOR_SEARCH_INDEX_ENDPOINT_PATH='' +export VECTOR_SEARCH_API_ENDPOINT='' ``` 4. Run the sample.