From 79e7dcf98fb549ee3a771a0f65121a67300d73d4 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 19 Feb 2025 13:03:32 -0800 Subject: [PATCH] feat(drivers-vector): support upserting/querying `ImageArtifact`s feat(drivers-embedding): support embedding `ImageArtifact`s --- .../amazon_bedrock_cohere_embedding_driver.py | 4 +- .../amazon_bedrock_titan_embedding_driver.py | 4 +- ...on_sagemaker_jumpstart_embedding_driver.py | 4 +- .../embedding/base_embedding_driver.py | 19 +++- .../embedding/cohere_embedding_driver.py | 4 +- .../embedding/dummy_embedding_driver.py | 2 +- .../embedding/google_embedding_driver.py | 4 +- .../huggingface_hub_embedding_driver.py | 4 +- .../embedding/ollama_embedding_driver.py | 4 +- .../embedding/openai_embedding_driver.py | 4 +- .../embedding/voyageai_embedding_driver.py | 4 +- .../azure_mongodb_vector_store_driver.py | 20 ---- .../vector/base_vector_store_driver.py | 91 ++++++++++++++----- .../vector/dummy_vector_store_driver.py | 3 +- .../griptape_cloud_vector_store_driver.py | 5 +- .../vector/marqo_vector_store_driver.py | 71 +++++++++++---- .../mongodb_atlas_vector_store_driver.py | 20 ---- .../vector/opensearch_vector_store_driver.py | 29 ------ .../vector/pgvector_vector_store_driver.py | 21 ----- .../vector/pinecone_vector_store_driver.py | 20 ---- 20 files changed, 170 insertions(+), 167 deletions(-) diff --git a/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py b/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py index 9614e3597..491c4d818 100644 --- a/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py +++ b/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py @@ -46,7 +46,9 @@ class AmazonBedrockCohereEmbeddingDriver(BaseEmbeddingDriver): def client(self) -> BedrockClient: return self.session.client("bedrock-runtime") - def try_embed_chunk(self, chunk: str) -> list[float]: + def try_embed_chunk(self, chunk: str | bytes) -> list[float]: + if isinstance(chunk, bytes): + raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.") payload = {"input_type": self.input_type, "texts": [chunk]} response = self.client.invoke_model( diff --git a/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py b/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py index c615358d2..cbc12ab0e 100644 --- a/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py +++ b/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py @@ -42,7 +42,9 @@ class AmazonBedrockTitanEmbeddingDriver(BaseEmbeddingDriver): def client(self) -> BedrockClient: return self.session.client("bedrock-runtime") - def try_embed_chunk(self, chunk: str) -> list[float]: + def try_embed_chunk(self, chunk: str | bytes) -> list[float]: + if isinstance(chunk, bytes): + raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.") payload = {"inputText": chunk} response = self.client.invoke_model( diff --git a/griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py b/griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py index 9a9a66493..9651ba4c4 100644 --- a/griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py +++ b/griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py @@ -26,7 +26,9 @@ class AmazonSageMakerJumpstartEmbeddingDriver(BaseEmbeddingDriver): def client(self) -> SageMakerClient: return self.session.client("sagemaker-runtime") - def try_embed_chunk(self, chunk: str) -> list[float]: + def try_embed_chunk(self, chunk: str | bytes) -> list[float]: + if isinstance(chunk, bytes): + raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.") payload = {"text_inputs": chunk, "mode": "embedding"} endpoint_response = self.client.invoke_endpoint( diff --git a/griptape/drivers/embedding/base_embedding_driver.py b/griptape/drivers/embedding/base_embedding_driver.py index cc7cffe96..a460a49db 100644 --- a/griptape/drivers/embedding/base_embedding_driver.py +++ b/griptape/drivers/embedding/base_embedding_driver.py @@ -6,12 +6,12 @@ import numpy as np from attrs import define, field +from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact from griptape.chunkers import BaseChunker, TextChunker from griptape.mixins.exponential_backoff_mixin import ExponentialBackoffMixin from griptape.mixins.serializable_mixin import SerializableMixin if TYPE_CHECKING: - from griptape.artifacts import TextArtifact from griptape.tokenizers import BaseTokenizer @@ -31,6 +31,21 @@ class BaseEmbeddingDriver(SerializableMixin, ExponentialBackoffMixin, ABC): def __attrs_post_init__(self) -> None: self.chunker = TextChunker(tokenizer=self.tokenizer) if self.tokenizer else None + def embed_artifact( + self, artifact: TextArtifact | ImageArtifact | ListArtifact[TextArtifact | ImageArtifact] + ) -> list[float]: + if isinstance(artifact, TextArtifact): + return self.embed_text_artifact(artifact) + elif isinstance(artifact, ImageArtifact): + return self.embed_image_artifact(artifact) + else: + embeddings = [self.embed_artifact(artifact) for artifact in artifact] + + return np.average(embeddings, axis=0).tolist() + + def embed_image_artifact(self, artifact: ImageArtifact) -> list[float]: + return self.try_embed_chunk(artifact.value) + def embed_text_artifact(self, artifact: TextArtifact) -> list[float]: return self.embed_string(artifact.to_text()) @@ -46,7 +61,7 @@ def embed_string(self, string: str) -> list[float]: raise RuntimeError("Failed to embed string.") @abstractmethod - def try_embed_chunk(self, chunk: str) -> list[float]: ... + def try_embed_chunk(self, chunk: str | bytes) -> list[float]: ... def _embed_long_string(self, string: str) -> list[float]: """Embeds a string that is too long to embed in one go. diff --git a/griptape/drivers/embedding/cohere_embedding_driver.py b/griptape/drivers/embedding/cohere_embedding_driver.py index b6dc69261..e226a93bb 100644 --- a/griptape/drivers/embedding/cohere_embedding_driver.py +++ b/griptape/drivers/embedding/cohere_embedding_driver.py @@ -39,7 +39,9 @@ class CohereEmbeddingDriver(BaseEmbeddingDriver): def client(self) -> Client: return import_optional_dependency("cohere").Client(self.api_key) - def try_embed_chunk(self, chunk: str) -> list[float]: + def try_embed_chunk(self, chunk: str | bytes) -> list[float]: + if isinstance(chunk, bytes): + raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.") result = self.client.embed(texts=[chunk], model=self.model, input_type=self.input_type) if isinstance(result.embeddings, list): diff --git a/griptape/drivers/embedding/dummy_embedding_driver.py b/griptape/drivers/embedding/dummy_embedding_driver.py index 1ad79ee03..5a9dc510d 100644 --- a/griptape/drivers/embedding/dummy_embedding_driver.py +++ b/griptape/drivers/embedding/dummy_embedding_driver.py @@ -10,5 +10,5 @@ class DummyEmbeddingDriver(BaseEmbeddingDriver): model: None = field(init=False, default=None, kw_only=True) - def try_embed_chunk(self, chunk: str) -> list[float]: + def try_embed_chunk(self, chunk: str | bytes) -> list[float]: raise DummyError(__class__.__name__, "try_embed_chunk") diff --git a/griptape/drivers/embedding/google_embedding_driver.py b/griptape/drivers/embedding/google_embedding_driver.py index 6524bef62..2c52dddcc 100644 --- a/griptape/drivers/embedding/google_embedding_driver.py +++ b/griptape/drivers/embedding/google_embedding_driver.py @@ -26,7 +26,9 @@ class GoogleEmbeddingDriver(BaseEmbeddingDriver): task_type: str = field(default="retrieval_document", kw_only=True, metadata={"serializable": True}) title: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) - def try_embed_chunk(self, chunk: str) -> list[float]: + def try_embed_chunk(self, chunk: str | bytes) -> list[float]: + if isinstance(chunk, bytes): + raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.") genai = import_optional_dependency("google.generativeai") genai.configure(api_key=self.api_key) diff --git a/griptape/drivers/embedding/huggingface_hub_embedding_driver.py b/griptape/drivers/embedding/huggingface_hub_embedding_driver.py index 154d23ef0..8a8fc20d1 100644 --- a/griptape/drivers/embedding/huggingface_hub_embedding_driver.py +++ b/griptape/drivers/embedding/huggingface_hub_embedding_driver.py @@ -32,7 +32,9 @@ def client(self) -> InferenceClient: token=self.api_token, ) - def try_embed_chunk(self, chunk: str) -> list[float]: + def try_embed_chunk(self, chunk: str | bytes) -> list[float]: + if isinstance(chunk, bytes): + raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.") response = self.client.feature_extraction(chunk) return response.flatten().tolist() diff --git a/griptape/drivers/embedding/ollama_embedding_driver.py b/griptape/drivers/embedding/ollama_embedding_driver.py index 7dde53547..9addf6356 100644 --- a/griptape/drivers/embedding/ollama_embedding_driver.py +++ b/griptape/drivers/embedding/ollama_embedding_driver.py @@ -30,5 +30,7 @@ class OllamaEmbeddingDriver(BaseEmbeddingDriver): def client(self) -> Client: return import_optional_dependency("ollama").Client(host=self.host) - def try_embed_chunk(self, chunk: str) -> list[float]: + def try_embed_chunk(self, chunk: str | bytes) -> list[float]: + if isinstance(chunk, bytes): + raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.") return list(self.client.embeddings(model=self.model, prompt=chunk)["embedding"]) diff --git a/griptape/drivers/embedding/openai_embedding_driver.py b/griptape/drivers/embedding/openai_embedding_driver.py index 17fbc9377..9fc74abd4 100644 --- a/griptape/drivers/embedding/openai_embedding_driver.py +++ b/griptape/drivers/embedding/openai_embedding_driver.py @@ -44,7 +44,9 @@ class OpenAiEmbeddingDriver(BaseEmbeddingDriver): def client(self) -> openai.OpenAI: return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization) - def try_embed_chunk(self, chunk: str) -> list[float]: + def try_embed_chunk(self, chunk: str | bytes) -> list[float]: + if isinstance(chunk, bytes): + raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.") # Address a performance issue in older ada models # https://github.com/openai/openai-python/issues/418#issuecomment-1525939500 if self.model.endswith("001"): diff --git a/griptape/drivers/embedding/voyageai_embedding_driver.py b/griptape/drivers/embedding/voyageai_embedding_driver.py index 30a009cc3..e8fe3abdf 100644 --- a/griptape/drivers/embedding/voyageai_embedding_driver.py +++ b/griptape/drivers/embedding/voyageai_embedding_driver.py @@ -40,5 +40,7 @@ class VoyageAiEmbeddingDriver(BaseEmbeddingDriver): def client(self) -> Any: return import_optional_dependency("voyageai").Client(api_key=self.api_key) - def try_embed_chunk(self, chunk: str) -> list[float]: + def try_embed_chunk(self, chunk: str | bytes) -> list[float]: + if isinstance(chunk, bytes): + raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.") return self.client.embed([chunk], model=self.model, input_type=self.input_type).embeddings[0] diff --git a/griptape/drivers/vector/azure_mongodb_vector_store_driver.py b/griptape/drivers/vector/azure_mongodb_vector_store_driver.py index 44964265d..25ab4788e 100644 --- a/griptape/drivers/vector/azure_mongodb_vector_store_driver.py +++ b/griptape/drivers/vector/azure_mongodb_vector_store_driver.py @@ -61,23 +61,3 @@ def query_vector( ) for doc in collection.aggregate(pipeline) ] - - def query( - self, - query: str, - *, - count: Optional[int] = None, - namespace: Optional[str] = None, - include_vectors: bool = False, - offset: Optional[int] = None, - **kwargs, - ) -> list[BaseVectorStoreDriver.Entry]: - """Queries the MongoDB collection for documents that match the provided query string. - - Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index. - """ - # Using the embedding driver to convert the query string into a vector - vector = self.embedding_driver.embed_string(query) - return self.query_vector( - vector, count=count, namespace=namespace, include_vectors=include_vectors, offset=offset, **kwargs - ) diff --git a/griptape/drivers/vector/base_vector_store_driver.py b/griptape/drivers/vector/base_vector_store_driver.py index 8f63ceac0..27f61227e 100644 --- a/griptape/drivers/vector/base_vector_store_driver.py +++ b/griptape/drivers/vector/base_vector_store_driver.py @@ -2,13 +2,13 @@ import uuid from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Optional, overload from attrs import define, field from griptape import utils from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact +from griptape.artifacts.image_artifact import ImageArtifact from griptape.mixins.futures_executor_mixin import FuturesExecutorMixin from griptape.mixins.serializable_mixin import SerializableMixin from griptape.utils import with_contextvars @@ -36,18 +36,65 @@ def to_artifact(self) -> BaseArtifact: def upsert_text_artifacts( self, - artifacts: list[TextArtifact] | dict[str, list[TextArtifact]], + artifacts: list[TextArtifact] | dict[str, list[TextArtifact] | list[ImageArtifact]], *, meta: Optional[dict] = None, **kwargs, ) -> list[str] | dict[str, list[str]]: + return self.upsert_collection(artifacts, meta=meta, **kwargs) + + def upsert_text_artifact( + self, + artifact: TextArtifact, + *, + namespace: Optional[str] = None, + meta: Optional[dict] = None, + vector_id: Optional[str] = None, + **kwargs, + ) -> str: + return self.upsert(artifact, namespace=namespace, meta=meta, vector_id=vector_id, **kwargs) + + def upsert_text( + self, + string: str, + *, + namespace: Optional[str] = None, + meta: Optional[dict] = None, + vector_id: Optional[str] = None, + **kwargs, + ) -> str: + return self.upsert(string, namespace=namespace, meta=meta, vector_id=vector_id, **kwargs) + + @overload + def upsert_collection( + self, + artifacts: list[TextArtifact], + *, + meta: Optional[dict] = None, + **kwargs, + ) -> list[str]: ... + + @overload + def upsert_collection( + self, + artifacts: dict[str, list[TextArtifact] | list[ImageArtifact]], + *, + meta: Optional[dict] = None, + **kwargs, + ) -> dict[str, list[str]]: ... + + def upsert_collection( + self, + artifacts: list[TextArtifact] | dict[str, list[TextArtifact] | list[ImageArtifact]], + *, + meta: Optional[dict] = None, + **kwargs, + ): with self.create_futures_executor() as futures_executor: if isinstance(artifacts, list): return utils.execute_futures_list( [ - futures_executor.submit( - with_contextvars(self.upsert_text_artifact), a, namespace=None, meta=meta, **kwargs - ) + futures_executor.submit(with_contextvars(self.upsert), a, namespace=None, meta=meta, **kwargs) for a in artifacts ], ) @@ -61,21 +108,23 @@ def upsert_text_artifacts( futures_dict[namespace].append( futures_executor.submit( - with_contextvars(self.upsert_text_artifact), a, namespace=namespace, meta=meta, **kwargs + with_contextvars(self.upsert), a, namespace=namespace, meta=meta, **kwargs ) ) return utils.execute_futures_list_dict(futures_dict) - def upsert_text_artifact( + def upsert( self, - artifact: TextArtifact, + value: str | TextArtifact | ImageArtifact, *, namespace: Optional[str] = None, meta: Optional[dict] = None, vector_id: Optional[str] = None, **kwargs, ) -> str: + artifact = TextArtifact(value) if isinstance(value, str) else value + meta = {} if meta is None else meta if vector_id is None: @@ -87,23 +136,10 @@ def upsert_text_artifact( else: meta["artifact"] = artifact.to_json() - vector = artifact.embedding or artifact.generate_embedding(self.embedding_driver) + vector = self.embedding_driver.embed_artifact(artifact) return self.upsert_vector(vector, vector_id=vector_id, namespace=namespace, meta=meta, **kwargs) - def upsert_text( - self, - string: str, - *, - namespace: Optional[str] = None, - meta: Optional[dict] = None, - vector_id: Optional[str] = None, - **kwargs, - ) -> str: - return self.upsert_text_artifact( - TextArtifact(string), vector_id=vector_id, namespace=namespace, meta=meta, **kwargs - ) - def does_entry_exist(self, vector_id: str, *, namespace: Optional[str] = None) -> bool: try: return self.load_entry(vector_id, namespace=namespace) is not None @@ -150,14 +186,19 @@ def query_vector( def query( self, - query: str, + query: str | BaseArtifact, *, count: Optional[int] = None, namespace: Optional[str] = None, include_vectors: bool = False, **kwargs, ) -> list[Entry]: - vector = self.embedding_driver.embed_string(query) + if isinstance(query, str): + vector = self.embedding_driver.embed_string(query) + elif isinstance(query, (TextArtifact, ImageArtifact, ListArtifact)): + vector = self.embedding_driver.embed_artifact(query) + else: + raise ValueError(f"Unsupported query type: {type(query)}") return self.query_vector(vector, count=count, namespace=namespace, include_vectors=include_vectors, **kwargs) def _get_default_vector_id(self, value: str) -> str: diff --git a/griptape/drivers/vector/dummy_vector_store_driver.py b/griptape/drivers/vector/dummy_vector_store_driver.py index 0f927fd12..b2dc0292a 100644 --- a/griptape/drivers/vector/dummy_vector_store_driver.py +++ b/griptape/drivers/vector/dummy_vector_store_driver.py @@ -9,6 +9,7 @@ from griptape.exceptions import DummyError if TYPE_CHECKING: + from griptape.artifacts import BaseArtifact from griptape.drivers.embedding import BaseEmbeddingDriver @@ -52,7 +53,7 @@ def query_vector( def query( self, - query: str, + query: str | BaseArtifact, *, count: Optional[int] = None, namespace: Optional[str] = None, diff --git a/griptape/drivers/vector/griptape_cloud_vector_store_driver.py b/griptape/drivers/vector/griptape_cloud_vector_store_driver.py index 12ba6aa9d..01ff993d2 100644 --- a/griptape/drivers/vector/griptape_cloud_vector_store_driver.py +++ b/griptape/drivers/vector/griptape_cloud_vector_store_driver.py @@ -7,6 +7,7 @@ import requests from attrs import Factory, define, field +from griptape.artifacts import BaseArtifact from griptape.drivers.embedding.dummy import DummyEmbeddingDriver from griptape.drivers.vector import BaseVectorStoreDriver @@ -83,7 +84,7 @@ def load_artifacts(self, *, namespace: Optional[str] = None) -> ListArtifact: def query( self, - query: str, + query: str | BaseArtifact, *, count: Optional[int] = None, namespace: Optional[str] = None, @@ -97,6 +98,8 @@ def query( Performs a query on the Knowledge Base and returns Artifacts with close vector proximity to the query, optionally filtering to only those that match the provided filter(s). """ + if isinstance(query, BaseArtifact): + raise ValueError(f"{self.__class__.__name__} does not support querying with Artifacts.") url = urljoin(self.base_url.strip("/"), f"/api/knowledge-bases/{self.knowledge_base_id}/query") query_args = { diff --git a/griptape/drivers/vector/marqo_vector_store_driver.py b/griptape/drivers/vector/marqo_vector_store_driver.py index 07d717480..9be749ade 100644 --- a/griptape/drivers/vector/marqo_vector_store_driver.py +++ b/griptape/drivers/vector/marqo_vector_store_driver.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: import marqo - from griptape.artifacts import TextArtifact + from griptape.artifacts import BaseArtifact, TextArtifact @define @@ -165,9 +165,42 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto return entries + def query_vector( + self, + vector: list[float], + *, + count: Optional[int] = None, + namespace: Optional[str] = None, + include_vectors: bool = False, + include_metadata: bool = True, + **kwargs, + ) -> list[BaseVectorStoreDriver.Entry]: + """Query the Marqo index for documents. + + Args: + vector: The vector to query by. + count: The maximum number of results to return. + namespace: The namespace to filter results by. + include_vectors: Whether to include vector data in the results. + include_metadata: Whether to include metadata in the results. + kwargs: Additional keyword arguments to pass to the Marqo client. + + Returns: + The list of query results. + """ + params = { + "limit": count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT, + "attributes_to_retrieve": None if include_metadata else ["_id"], + "filter_string": f"namespace:{namespace}" if namespace else None, + } | kwargs + + results = self.client.index(self.index).search(**params, context={"tensor": [vector], "weight": 1}) + + return self.__process_results(results, include_vectors=include_vectors) + def query( self, - query: str, + query: str | BaseArtifact, *, count: Optional[int] = None, namespace: Optional[str] = None, @@ -195,22 +228,7 @@ def query( } | kwargs results = self.client.index(self.index).search(query, **params) - - if include_vectors: - results["hits"] = [ - {**r, **self.client.index(self.index).get_document(r["_id"], expose_facets=True)} - for r in results["hits"] - ] - - return [ - BaseVectorStoreDriver.Entry( - id=r["_id"], - vector=r["_tensor_facets"][0]["_embedding"] if include_vectors else [], - score=r["_score"], - meta={k: v for k, v in r.items() if k not in ["_score", "_tensor_facets"]}, - ) - for r in results["hits"] - ] + return self.__process_results(results, include_vectors=include_vectors) def delete_index(self, name: str) -> dict[str, Any]: """Delete an index in the Marqo client. @@ -256,3 +274,20 @@ def upsert_vector( def delete_vector(self, vector_id: str) -> NoReturn: raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.") + + def __process_results(self, results: dict, *, include_vectors: bool) -> list[BaseVectorStoreDriver.Entry]: + if include_vectors: + results["hits"] = [ + {**r, **self.client.index(self.index).get_document(r["_id"], expose_facets=True)} + for r in results["hits"] + ] + + return [ + BaseVectorStoreDriver.Entry( + id=r["_id"], + vector=r["_tensor_facets"][0]["_embedding"] if include_vectors else [], + score=r["_score"], + meta={k: v for k, v in r.items() if k not in ["_score", "_tensor_facets"]}, + ) + for r in results["hits"] + ] diff --git a/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py b/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py index d82f8f768..783725912 100644 --- a/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py +++ b/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py @@ -168,26 +168,6 @@ def query_vector( for doc in collection.aggregate(pipeline) ] - def query( - self, - query: str, - *, - count: Optional[int] = None, - namespace: Optional[str] = None, - include_vectors: bool = False, - offset: Optional[int] = None, - **kwargs, - ) -> list[BaseVectorStoreDriver.Entry]: - """Queries the MongoDB collection for documents that match the provided query string. - - Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index. - """ - # Using the embedding driver to convert the query string into a vector - vector = self.embedding_driver.embed_string(query) - return self.query_vector( - vector, count=count, namespace=namespace, include_vectors=include_vectors, offset=offset, **kwargs - ) - def delete_vector(self, vector_id: str) -> None: """Deletes the vector from the collection.""" collection = self.get_collection() diff --git a/griptape/drivers/vector/opensearch_vector_store_driver.py b/griptape/drivers/vector/opensearch_vector_store_driver.py index 23ecf600f..ade185b0a 100644 --- a/griptape/drivers/vector/opensearch_vector_store_driver.py +++ b/griptape/drivers/vector/opensearch_vector_store_driver.py @@ -164,34 +164,5 @@ def query_vector( for hit in response["hits"]["hits"] ] - def query( - self, - query: str, - *, - count: Optional[int] = None, - namespace: Optional[str] = None, - include_vectors: bool = False, - include_metadata: bool = True, - field_name: str = "vector", - **kwargs, - ) -> list[BaseVectorStoreDriver.Entry]: - """Performs a nearest neighbor search on OpenSearch to find vectors similar to the provided query string. - - Results can be limited using the count parameter and optionally filtered by a namespace. - - Returns: - A list of BaseVectorStoreDriver.Entry objects, each encapsulating the retrieved vector, its similarity score, metadata, and namespace. - """ - vector = self.embedding_driver.embed_string(query) - return self.query_vector( - vector, - count=count, - namespace=namespace, - include_vectors=include_vectors, - include_metadata=include_metadata, - field_name=field_name, - **kwargs, - ) - def delete_vector(self, vector_id: str) -> NoReturn: raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.") diff --git a/griptape/drivers/vector/pgvector_vector_store_driver.py b/griptape/drivers/vector/pgvector_vector_store_driver.py index ebbdbacde..e3fa9012a 100644 --- a/griptape/drivers/vector/pgvector_vector_store_driver.py +++ b/griptape/drivers/vector/pgvector_vector_store_driver.py @@ -180,27 +180,6 @@ def query_vector( for result in results ] - def query( - self, - query: str, - *, - count: Optional[int] = BaseVectorStoreDriver.DEFAULT_QUERY_COUNT, - namespace: Optional[str] = None, - include_vectors: bool = False, - distance_metric: str = "cosine_distance", - **kwargs, - ) -> list[BaseVectorStoreDriver.Entry]: - """Performs a search on the collection to find vectors similar to the provided input vector, optionally filtering to only those that match the provided namespace.""" - vector = self.embedding_driver.embed_string(query) - return self.query_vector( - vector, - count=count, - namespace=namespace, - include_vectors=include_vectors, - distance_metric=distance_metric, - **kwargs, - ) - def default_vector_model(self) -> Any: pgvector_sqlalchemy = import_optional_dependency("pgvector.sqlalchemy") sqlalchemy = import_optional_dependency("sqlalchemy") diff --git a/griptape/drivers/vector/pinecone_vector_store_driver.py b/griptape/drivers/vector/pinecone_vector_store_driver.py index c6c95d842..d64cceca2 100644 --- a/griptape/drivers/vector/pinecone_vector_store_driver.py +++ b/griptape/drivers/vector/pinecone_vector_store_driver.py @@ -117,25 +117,5 @@ def query_vector( for r in results["matches"] ] - def query( - self, - query: str, - *, - count: Optional[int] = None, - namespace: Optional[str] = None, - include_vectors: bool = False, - include_metadata: bool = True, - **kwargs, - ) -> list[BaseVectorStoreDriver.Entry]: - vector = self.embedding_driver.embed_string(query) - return self.query_vector( - vector, - count=count, - namespace=namespace, - include_vectors=include_vectors, - include_metadata=include_metadata, - **kwargs, - ) - def delete_vector(self, vector_id: str) -> NoReturn: raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")