From c4e288766e064dd2ab41c99fe26d1a9f2da9ca7d Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 19 Feb 2025 13:03:32 -0800 Subject: [PATCH] feat(drivers-vector): support upserting/querying `ImageArtifact`s feat(drivers-embedding): support embedding `ImageArtifact`s --- .../amazon_bedrock_cohere_embedding_driver.py | 4 +- .../amazon_bedrock_titan_embedding_driver.py | 4 +- ...on_sagemaker_jumpstart_embedding_driver.py | 4 +- .../embedding/base_embedding_driver.py | 19 +++- .../embedding/cohere_embedding_driver.py | 4 +- .../embedding/dummy_embedding_driver.py | 2 +- .../embedding/google_embedding_driver.py | 4 +- .../huggingface_hub_embedding_driver.py | 4 +- .../embedding/ollama_embedding_driver.py | 4 +- .../embedding/openai_embedding_driver.py | 4 +- .../embedding/voyageai_embedding_driver.py | 4 +- .../azure_mongodb_vector_store_driver.py | 20 ---- .../vector/base_vector_store_driver.py | 91 ++++++++++++++----- .../vector/dummy_vector_store_driver.py | 3 +- .../griptape_cloud_vector_store_driver.py | 5 +- .../vector/marqo_vector_store_driver.py | 71 +++++++++++---- .../mongodb_atlas_vector_store_driver.py | 20 ---- .../vector/opensearch_vector_store_driver.py | 29 ------ .../vector/pgvector_vector_store_driver.py | 21 ----- .../vector/pinecone_vector_store_driver.py | 20 ---- ..._amazon_bedrock_cohere_embedding_driver.py | 17 +++- ...t_amazon_bedrock_titan_embedding_driver.py | 17 +++- .../test_azure_openai_embedding_driver.py | 17 +++- .../embedding/test_cohere_embedding_driver.py | 24 ++++- .../embedding/test_google_embedding_driver.py | 17 +++- .../test_hugging_face_hub_embedding_driver.py | 38 ++++++++ .../embedding/test_ollama_embedding_driver.py | 18 +++- .../embedding/test_openai_embedding_driver.py | 13 ++- ...st_sagemaker_jumpstart_embedding_driver.py | 12 +++ .../test_voyageai_embedding_driver.py | 17 +++- ...test_griptape_cloud_vector_store_driver.py | 7 ++ .../vector/test_marqo_vector_store_driver.py | 11 ++- 32 files changed, 357 insertions(+), 188 deletions(-) create mode 100644 tests/unit/drivers/embedding/test_hugging_face_hub_embedding_driver.py diff --git a/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py b/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py index 9614e3597..491c4d818 100644 --- a/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py +++ b/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py @@ -46,7 +46,9 @@ class AmazonBedrockCohereEmbeddingDriver(BaseEmbeddingDriver): def client(self) -> BedrockClient: return self.session.client("bedrock-runtime") - def try_embed_chunk(self, chunk: str) -> list[float]: + def try_embed_chunk(self, chunk: str | bytes) -> list[float]: + if isinstance(chunk, bytes): + raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.") payload = {"input_type": self.input_type, "texts": [chunk]} response = self.client.invoke_model( diff --git a/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py b/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py index c615358d2..cbc12ab0e 100644 --- a/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py +++ b/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py @@ -42,7 +42,9 @@ class AmazonBedrockTitanEmbeddingDriver(BaseEmbeddingDriver): def client(self) -> BedrockClient: return self.session.client("bedrock-runtime") - def try_embed_chunk(self, chunk: str) -> list[float]: + def try_embed_chunk(self, chunk: str | bytes) -> list[float]: + if isinstance(chunk, bytes): + raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.") payload = {"inputText": chunk} response = self.client.invoke_model( diff --git a/griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py b/griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py index 9a9a66493..9651ba4c4 100644 --- a/griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py +++ b/griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py @@ -26,7 +26,9 @@ class AmazonSageMakerJumpstartEmbeddingDriver(BaseEmbeddingDriver): def client(self) -> SageMakerClient: return self.session.client("sagemaker-runtime") - def try_embed_chunk(self, chunk: str) -> list[float]: + def try_embed_chunk(self, chunk: str | bytes) -> list[float]: + if isinstance(chunk, bytes): + raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.") payload = {"text_inputs": chunk, "mode": "embedding"} endpoint_response = self.client.invoke_endpoint( diff --git a/griptape/drivers/embedding/base_embedding_driver.py b/griptape/drivers/embedding/base_embedding_driver.py index cc7cffe96..a460a49db 100644 --- a/griptape/drivers/embedding/base_embedding_driver.py +++ b/griptape/drivers/embedding/base_embedding_driver.py @@ -6,12 +6,12 @@ import numpy as np from attrs import define, field +from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact from griptape.chunkers import BaseChunker, TextChunker from griptape.mixins.exponential_backoff_mixin import ExponentialBackoffMixin from griptape.mixins.serializable_mixin import SerializableMixin if TYPE_CHECKING: - from griptape.artifacts import TextArtifact from griptape.tokenizers import BaseTokenizer @@ -31,6 +31,21 @@ class BaseEmbeddingDriver(SerializableMixin, ExponentialBackoffMixin, ABC): def __attrs_post_init__(self) -> None: self.chunker = TextChunker(tokenizer=self.tokenizer) if self.tokenizer else None + def embed_artifact( + self, artifact: TextArtifact | ImageArtifact | ListArtifact[TextArtifact | ImageArtifact] + ) -> list[float]: + if isinstance(artifact, TextArtifact): + return self.embed_text_artifact(artifact) + elif isinstance(artifact, ImageArtifact): + return self.embed_image_artifact(artifact) + else: + embeddings = [self.embed_artifact(artifact) for artifact in artifact] + + return np.average(embeddings, axis=0).tolist() + + def embed_image_artifact(self, artifact: ImageArtifact) -> list[float]: + return self.try_embed_chunk(artifact.value) + def embed_text_artifact(self, artifact: TextArtifact) -> list[float]: return self.embed_string(artifact.to_text()) @@ -46,7 +61,7 @@ def embed_string(self, string: str) -> list[float]: raise RuntimeError("Failed to embed string.") @abstractmethod - def try_embed_chunk(self, chunk: str) -> list[float]: ... + def try_embed_chunk(self, chunk: str | bytes) -> list[float]: ... def _embed_long_string(self, string: str) -> list[float]: """Embeds a string that is too long to embed in one go. diff --git a/griptape/drivers/embedding/cohere_embedding_driver.py b/griptape/drivers/embedding/cohere_embedding_driver.py index b6dc69261..e226a93bb 100644 --- a/griptape/drivers/embedding/cohere_embedding_driver.py +++ b/griptape/drivers/embedding/cohere_embedding_driver.py @@ -39,7 +39,9 @@ class CohereEmbeddingDriver(BaseEmbeddingDriver): def client(self) -> Client: return import_optional_dependency("cohere").Client(self.api_key) - def try_embed_chunk(self, chunk: str) -> list[float]: + def try_embed_chunk(self, chunk: str | bytes) -> list[float]: + if isinstance(chunk, bytes): + raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.") result = self.client.embed(texts=[chunk], model=self.model, input_type=self.input_type) if isinstance(result.embeddings, list): diff --git a/griptape/drivers/embedding/dummy_embedding_driver.py b/griptape/drivers/embedding/dummy_embedding_driver.py index 1ad79ee03..5a9dc510d 100644 --- a/griptape/drivers/embedding/dummy_embedding_driver.py +++ b/griptape/drivers/embedding/dummy_embedding_driver.py @@ -10,5 +10,5 @@ class DummyEmbeddingDriver(BaseEmbeddingDriver): model: None = field(init=False, default=None, kw_only=True) - def try_embed_chunk(self, chunk: str) -> list[float]: + def try_embed_chunk(self, chunk: str | bytes) -> list[float]: raise DummyError(__class__.__name__, "try_embed_chunk") diff --git a/griptape/drivers/embedding/google_embedding_driver.py b/griptape/drivers/embedding/google_embedding_driver.py index 6524bef62..2c52dddcc 100644 --- a/griptape/drivers/embedding/google_embedding_driver.py +++ b/griptape/drivers/embedding/google_embedding_driver.py @@ -26,7 +26,9 @@ class GoogleEmbeddingDriver(BaseEmbeddingDriver): task_type: str = field(default="retrieval_document", kw_only=True, metadata={"serializable": True}) title: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) - def try_embed_chunk(self, chunk: str) -> list[float]: + def try_embed_chunk(self, chunk: str | bytes) -> list[float]: + if isinstance(chunk, bytes): + raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.") genai = import_optional_dependency("google.generativeai") genai.configure(api_key=self.api_key) diff --git a/griptape/drivers/embedding/huggingface_hub_embedding_driver.py b/griptape/drivers/embedding/huggingface_hub_embedding_driver.py index 154d23ef0..8a8fc20d1 100644 --- a/griptape/drivers/embedding/huggingface_hub_embedding_driver.py +++ b/griptape/drivers/embedding/huggingface_hub_embedding_driver.py @@ -32,7 +32,9 @@ def client(self) -> InferenceClient: token=self.api_token, ) - def try_embed_chunk(self, chunk: str) -> list[float]: + def try_embed_chunk(self, chunk: str | bytes) -> list[float]: + if isinstance(chunk, bytes): + raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.") response = self.client.feature_extraction(chunk) return response.flatten().tolist() diff --git a/griptape/drivers/embedding/ollama_embedding_driver.py b/griptape/drivers/embedding/ollama_embedding_driver.py index 7dde53547..9addf6356 100644 --- a/griptape/drivers/embedding/ollama_embedding_driver.py +++ b/griptape/drivers/embedding/ollama_embedding_driver.py @@ -30,5 +30,7 @@ class OllamaEmbeddingDriver(BaseEmbeddingDriver): def client(self) -> Client: return import_optional_dependency("ollama").Client(host=self.host) - def try_embed_chunk(self, chunk: str) -> list[float]: + def try_embed_chunk(self, chunk: str | bytes) -> list[float]: + if isinstance(chunk, bytes): + raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.") return list(self.client.embeddings(model=self.model, prompt=chunk)["embedding"]) diff --git a/griptape/drivers/embedding/openai_embedding_driver.py b/griptape/drivers/embedding/openai_embedding_driver.py index 17fbc9377..9fc74abd4 100644 --- a/griptape/drivers/embedding/openai_embedding_driver.py +++ b/griptape/drivers/embedding/openai_embedding_driver.py @@ -44,7 +44,9 @@ class OpenAiEmbeddingDriver(BaseEmbeddingDriver): def client(self) -> openai.OpenAI: return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization) - def try_embed_chunk(self, chunk: str) -> list[float]: + def try_embed_chunk(self, chunk: str | bytes) -> list[float]: + if isinstance(chunk, bytes): + raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.") # Address a performance issue in older ada models # https://github.com/openai/openai-python/issues/418#issuecomment-1525939500 if self.model.endswith("001"): diff --git a/griptape/drivers/embedding/voyageai_embedding_driver.py b/griptape/drivers/embedding/voyageai_embedding_driver.py index 30a009cc3..e8fe3abdf 100644 --- a/griptape/drivers/embedding/voyageai_embedding_driver.py +++ b/griptape/drivers/embedding/voyageai_embedding_driver.py @@ -40,5 +40,7 @@ class VoyageAiEmbeddingDriver(BaseEmbeddingDriver): def client(self) -> Any: return import_optional_dependency("voyageai").Client(api_key=self.api_key) - def try_embed_chunk(self, chunk: str) -> list[float]: + def try_embed_chunk(self, chunk: str | bytes) -> list[float]: + if isinstance(chunk, bytes): + raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.") return self.client.embed([chunk], model=self.model, input_type=self.input_type).embeddings[0] diff --git a/griptape/drivers/vector/azure_mongodb_vector_store_driver.py b/griptape/drivers/vector/azure_mongodb_vector_store_driver.py index 44964265d..25ab4788e 100644 --- a/griptape/drivers/vector/azure_mongodb_vector_store_driver.py +++ b/griptape/drivers/vector/azure_mongodb_vector_store_driver.py @@ -61,23 +61,3 @@ def query_vector( ) for doc in collection.aggregate(pipeline) ] - - def query( - self, - query: str, - *, - count: Optional[int] = None, - namespace: Optional[str] = None, - include_vectors: bool = False, - offset: Optional[int] = None, - **kwargs, - ) -> list[BaseVectorStoreDriver.Entry]: - """Queries the MongoDB collection for documents that match the provided query string. - - Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index. - """ - # Using the embedding driver to convert the query string into a vector - vector = self.embedding_driver.embed_string(query) - return self.query_vector( - vector, count=count, namespace=namespace, include_vectors=include_vectors, offset=offset, **kwargs - ) diff --git a/griptape/drivers/vector/base_vector_store_driver.py b/griptape/drivers/vector/base_vector_store_driver.py index 8f63ceac0..27f61227e 100644 --- a/griptape/drivers/vector/base_vector_store_driver.py +++ b/griptape/drivers/vector/base_vector_store_driver.py @@ -2,13 +2,13 @@ import uuid from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Optional, overload from attrs import define, field from griptape import utils from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact +from griptape.artifacts.image_artifact import ImageArtifact from griptape.mixins.futures_executor_mixin import FuturesExecutorMixin from griptape.mixins.serializable_mixin import SerializableMixin from griptape.utils import with_contextvars @@ -36,18 +36,65 @@ def to_artifact(self) -> BaseArtifact: def upsert_text_artifacts( self, - artifacts: list[TextArtifact] | dict[str, list[TextArtifact]], + artifacts: list[TextArtifact] | dict[str, list[TextArtifact] | list[ImageArtifact]], *, meta: Optional[dict] = None, **kwargs, ) -> list[str] | dict[str, list[str]]: + return self.upsert_collection(artifacts, meta=meta, **kwargs) + + def upsert_text_artifact( + self, + artifact: TextArtifact, + *, + namespace: Optional[str] = None, + meta: Optional[dict] = None, + vector_id: Optional[str] = None, + **kwargs, + ) -> str: + return self.upsert(artifact, namespace=namespace, meta=meta, vector_id=vector_id, **kwargs) + + def upsert_text( + self, + string: str, + *, + namespace: Optional[str] = None, + meta: Optional[dict] = None, + vector_id: Optional[str] = None, + **kwargs, + ) -> str: + return self.upsert(string, namespace=namespace, meta=meta, vector_id=vector_id, **kwargs) + + @overload + def upsert_collection( + self, + artifacts: list[TextArtifact], + *, + meta: Optional[dict] = None, + **kwargs, + ) -> list[str]: ... + + @overload + def upsert_collection( + self, + artifacts: dict[str, list[TextArtifact] | list[ImageArtifact]], + *, + meta: Optional[dict] = None, + **kwargs, + ) -> dict[str, list[str]]: ... + + def upsert_collection( + self, + artifacts: list[TextArtifact] | dict[str, list[TextArtifact] | list[ImageArtifact]], + *, + meta: Optional[dict] = None, + **kwargs, + ): with self.create_futures_executor() as futures_executor: if isinstance(artifacts, list): return utils.execute_futures_list( [ - futures_executor.submit( - with_contextvars(self.upsert_text_artifact), a, namespace=None, meta=meta, **kwargs - ) + futures_executor.submit(with_contextvars(self.upsert), a, namespace=None, meta=meta, **kwargs) for a in artifacts ], ) @@ -61,21 +108,23 @@ def upsert_text_artifacts( futures_dict[namespace].append( futures_executor.submit( - with_contextvars(self.upsert_text_artifact), a, namespace=namespace, meta=meta, **kwargs + with_contextvars(self.upsert), a, namespace=namespace, meta=meta, **kwargs ) ) return utils.execute_futures_list_dict(futures_dict) - def upsert_text_artifact( + def upsert( self, - artifact: TextArtifact, + value: str | TextArtifact | ImageArtifact, *, namespace: Optional[str] = None, meta: Optional[dict] = None, vector_id: Optional[str] = None, **kwargs, ) -> str: + artifact = TextArtifact(value) if isinstance(value, str) else value + meta = {} if meta is None else meta if vector_id is None: @@ -87,23 +136,10 @@ def upsert_text_artifact( else: meta["artifact"] = artifact.to_json() - vector = artifact.embedding or artifact.generate_embedding(self.embedding_driver) + vector = self.embedding_driver.embed_artifact(artifact) return self.upsert_vector(vector, vector_id=vector_id, namespace=namespace, meta=meta, **kwargs) - def upsert_text( - self, - string: str, - *, - namespace: Optional[str] = None, - meta: Optional[dict] = None, - vector_id: Optional[str] = None, - **kwargs, - ) -> str: - return self.upsert_text_artifact( - TextArtifact(string), vector_id=vector_id, namespace=namespace, meta=meta, **kwargs - ) - def does_entry_exist(self, vector_id: str, *, namespace: Optional[str] = None) -> bool: try: return self.load_entry(vector_id, namespace=namespace) is not None @@ -150,14 +186,19 @@ def query_vector( def query( self, - query: str, + query: str | BaseArtifact, *, count: Optional[int] = None, namespace: Optional[str] = None, include_vectors: bool = False, **kwargs, ) -> list[Entry]: - vector = self.embedding_driver.embed_string(query) + if isinstance(query, str): + vector = self.embedding_driver.embed_string(query) + elif isinstance(query, (TextArtifact, ImageArtifact, ListArtifact)): + vector = self.embedding_driver.embed_artifact(query) + else: + raise ValueError(f"Unsupported query type: {type(query)}") return self.query_vector(vector, count=count, namespace=namespace, include_vectors=include_vectors, **kwargs) def _get_default_vector_id(self, value: str) -> str: diff --git a/griptape/drivers/vector/dummy_vector_store_driver.py b/griptape/drivers/vector/dummy_vector_store_driver.py index 0f927fd12..b2dc0292a 100644 --- a/griptape/drivers/vector/dummy_vector_store_driver.py +++ b/griptape/drivers/vector/dummy_vector_store_driver.py @@ -9,6 +9,7 @@ from griptape.exceptions import DummyError if TYPE_CHECKING: + from griptape.artifacts import BaseArtifact from griptape.drivers.embedding import BaseEmbeddingDriver @@ -52,7 +53,7 @@ def query_vector( def query( self, - query: str, + query: str | BaseArtifact, *, count: Optional[int] = None, namespace: Optional[str] = None, diff --git a/griptape/drivers/vector/griptape_cloud_vector_store_driver.py b/griptape/drivers/vector/griptape_cloud_vector_store_driver.py index 12ba6aa9d..01ff993d2 100644 --- a/griptape/drivers/vector/griptape_cloud_vector_store_driver.py +++ b/griptape/drivers/vector/griptape_cloud_vector_store_driver.py @@ -7,6 +7,7 @@ import requests from attrs import Factory, define, field +from griptape.artifacts import BaseArtifact from griptape.drivers.embedding.dummy import DummyEmbeddingDriver from griptape.drivers.vector import BaseVectorStoreDriver @@ -83,7 +84,7 @@ def load_artifacts(self, *, namespace: Optional[str] = None) -> ListArtifact: def query( self, - query: str, + query: str | BaseArtifact, *, count: Optional[int] = None, namespace: Optional[str] = None, @@ -97,6 +98,8 @@ def query( Performs a query on the Knowledge Base and returns Artifacts with close vector proximity to the query, optionally filtering to only those that match the provided filter(s). """ + if isinstance(query, BaseArtifact): + raise ValueError(f"{self.__class__.__name__} does not support querying with Artifacts.") url = urljoin(self.base_url.strip("/"), f"/api/knowledge-bases/{self.knowledge_base_id}/query") query_args = { diff --git a/griptape/drivers/vector/marqo_vector_store_driver.py b/griptape/drivers/vector/marqo_vector_store_driver.py index 07d717480..9be749ade 100644 --- a/griptape/drivers/vector/marqo_vector_store_driver.py +++ b/griptape/drivers/vector/marqo_vector_store_driver.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: import marqo - from griptape.artifacts import TextArtifact + from griptape.artifacts import BaseArtifact, TextArtifact @define @@ -165,9 +165,42 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto return entries + def query_vector( + self, + vector: list[float], + *, + count: Optional[int] = None, + namespace: Optional[str] = None, + include_vectors: bool = False, + include_metadata: bool = True, + **kwargs, + ) -> list[BaseVectorStoreDriver.Entry]: + """Query the Marqo index for documents. + + Args: + vector: The vector to query by. + count: The maximum number of results to return. + namespace: The namespace to filter results by. + include_vectors: Whether to include vector data in the results. + include_metadata: Whether to include metadata in the results. + kwargs: Additional keyword arguments to pass to the Marqo client. + + Returns: + The list of query results. + """ + params = { + "limit": count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT, + "attributes_to_retrieve": None if include_metadata else ["_id"], + "filter_string": f"namespace:{namespace}" if namespace else None, + } | kwargs + + results = self.client.index(self.index).search(**params, context={"tensor": [vector], "weight": 1}) + + return self.__process_results(results, include_vectors=include_vectors) + def query( self, - query: str, + query: str | BaseArtifact, *, count: Optional[int] = None, namespace: Optional[str] = None, @@ -195,22 +228,7 @@ def query( } | kwargs results = self.client.index(self.index).search(query, **params) - - if include_vectors: - results["hits"] = [ - {**r, **self.client.index(self.index).get_document(r["_id"], expose_facets=True)} - for r in results["hits"] - ] - - return [ - BaseVectorStoreDriver.Entry( - id=r["_id"], - vector=r["_tensor_facets"][0]["_embedding"] if include_vectors else [], - score=r["_score"], - meta={k: v for k, v in r.items() if k not in ["_score", "_tensor_facets"]}, - ) - for r in results["hits"] - ] + return self.__process_results(results, include_vectors=include_vectors) def delete_index(self, name: str) -> dict[str, Any]: """Delete an index in the Marqo client. @@ -256,3 +274,20 @@ def upsert_vector( def delete_vector(self, vector_id: str) -> NoReturn: raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.") + + def __process_results(self, results: dict, *, include_vectors: bool) -> list[BaseVectorStoreDriver.Entry]: + if include_vectors: + results["hits"] = [ + {**r, **self.client.index(self.index).get_document(r["_id"], expose_facets=True)} + for r in results["hits"] + ] + + return [ + BaseVectorStoreDriver.Entry( + id=r["_id"], + vector=r["_tensor_facets"][0]["_embedding"] if include_vectors else [], + score=r["_score"], + meta={k: v for k, v in r.items() if k not in ["_score", "_tensor_facets"]}, + ) + for r in results["hits"] + ] diff --git a/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py b/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py index d82f8f768..783725912 100644 --- a/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py +++ b/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py @@ -168,26 +168,6 @@ def query_vector( for doc in collection.aggregate(pipeline) ] - def query( - self, - query: str, - *, - count: Optional[int] = None, - namespace: Optional[str] = None, - include_vectors: bool = False, - offset: Optional[int] = None, - **kwargs, - ) -> list[BaseVectorStoreDriver.Entry]: - """Queries the MongoDB collection for documents that match the provided query string. - - Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index. - """ - # Using the embedding driver to convert the query string into a vector - vector = self.embedding_driver.embed_string(query) - return self.query_vector( - vector, count=count, namespace=namespace, include_vectors=include_vectors, offset=offset, **kwargs - ) - def delete_vector(self, vector_id: str) -> None: """Deletes the vector from the collection.""" collection = self.get_collection() diff --git a/griptape/drivers/vector/opensearch_vector_store_driver.py b/griptape/drivers/vector/opensearch_vector_store_driver.py index 23ecf600f..ade185b0a 100644 --- a/griptape/drivers/vector/opensearch_vector_store_driver.py +++ b/griptape/drivers/vector/opensearch_vector_store_driver.py @@ -164,34 +164,5 @@ def query_vector( for hit in response["hits"]["hits"] ] - def query( - self, - query: str, - *, - count: Optional[int] = None, - namespace: Optional[str] = None, - include_vectors: bool = False, - include_metadata: bool = True, - field_name: str = "vector", - **kwargs, - ) -> list[BaseVectorStoreDriver.Entry]: - """Performs a nearest neighbor search on OpenSearch to find vectors similar to the provided query string. - - Results can be limited using the count parameter and optionally filtered by a namespace. - - Returns: - A list of BaseVectorStoreDriver.Entry objects, each encapsulating the retrieved vector, its similarity score, metadata, and namespace. - """ - vector = self.embedding_driver.embed_string(query) - return self.query_vector( - vector, - count=count, - namespace=namespace, - include_vectors=include_vectors, - include_metadata=include_metadata, - field_name=field_name, - **kwargs, - ) - def delete_vector(self, vector_id: str) -> NoReturn: raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.") diff --git a/griptape/drivers/vector/pgvector_vector_store_driver.py b/griptape/drivers/vector/pgvector_vector_store_driver.py index ebbdbacde..e3fa9012a 100644 --- a/griptape/drivers/vector/pgvector_vector_store_driver.py +++ b/griptape/drivers/vector/pgvector_vector_store_driver.py @@ -180,27 +180,6 @@ def query_vector( for result in results ] - def query( - self, - query: str, - *, - count: Optional[int] = BaseVectorStoreDriver.DEFAULT_QUERY_COUNT, - namespace: Optional[str] = None, - include_vectors: bool = False, - distance_metric: str = "cosine_distance", - **kwargs, - ) -> list[BaseVectorStoreDriver.Entry]: - """Performs a search on the collection to find vectors similar to the provided input vector, optionally filtering to only those that match the provided namespace.""" - vector = self.embedding_driver.embed_string(query) - return self.query_vector( - vector, - count=count, - namespace=namespace, - include_vectors=include_vectors, - distance_metric=distance_metric, - **kwargs, - ) - def default_vector_model(self) -> Any: pgvector_sqlalchemy = import_optional_dependency("pgvector.sqlalchemy") sqlalchemy = import_optional_dependency("sqlalchemy") diff --git a/griptape/drivers/vector/pinecone_vector_store_driver.py b/griptape/drivers/vector/pinecone_vector_store_driver.py index c6c95d842..d64cceca2 100644 --- a/griptape/drivers/vector/pinecone_vector_store_driver.py +++ b/griptape/drivers/vector/pinecone_vector_store_driver.py @@ -117,25 +117,5 @@ def query_vector( for r in results["matches"] ] - def query( - self, - query: str, - *, - count: Optional[int] = None, - namespace: Optional[str] = None, - include_vectors: bool = False, - include_metadata: bool = True, - **kwargs, - ) -> list[BaseVectorStoreDriver.Entry]: - vector = self.embedding_driver.embed_string(query) - return self.query_vector( - vector, - count=count, - namespace=namespace, - include_vectors=include_vectors, - include_metadata=include_metadata, - **kwargs, - ) - def delete_vector(self, vector_id: str) -> NoReturn: raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.") diff --git a/tests/unit/drivers/embedding/test_amazon_bedrock_cohere_embedding_driver.py b/tests/unit/drivers/embedding/test_amazon_bedrock_cohere_embedding_driver.py index ca9fe2dbe..80856cf64 100644 --- a/tests/unit/drivers/embedding/test_amazon_bedrock_cohere_embedding_driver.py +++ b/tests/unit/drivers/embedding/test_amazon_bedrock_cohere_embedding_driver.py @@ -1,3 +1,4 @@ +from contextlib import nullcontext from unittest import mock import pytest @@ -24,5 +25,17 @@ def _mock_session(self, mocker): def test_init(self): assert AmazonBedrockCohereEmbeddingDriver() - def test_try_embed_chunk(self): - assert AmazonBedrockCohereEmbeddingDriver().try_embed_chunk("foobar") == [0, 1, 0] + @pytest.mark.parametrize( + ("chunk", "expected_output", "expected_error"), + [ + ("foobar", [0, 1, 0], nullcontext()), + ( + b"foobar", + [], + pytest.raises(ValueError, match="AmazonBedrockCohereEmbeddingDriver does not support embedding bytes."), + ), + ], + ) + def test_try_embed_chunk(self, chunk, expected_output, expected_error): + with expected_error: + assert AmazonBedrockCohereEmbeddingDriver().try_embed_chunk(chunk) == expected_output diff --git a/tests/unit/drivers/embedding/test_amazon_bedrock_titan_embedding_driver.py b/tests/unit/drivers/embedding/test_amazon_bedrock_titan_embedding_driver.py index d428040d1..79fe964b0 100644 --- a/tests/unit/drivers/embedding/test_amazon_bedrock_titan_embedding_driver.py +++ b/tests/unit/drivers/embedding/test_amazon_bedrock_titan_embedding_driver.py @@ -1,3 +1,4 @@ +from contextlib import nullcontext from unittest import mock import pytest @@ -24,5 +25,17 @@ def _mock_session(self, mocker): def test_init(self): assert AmazonBedrockTitanEmbeddingDriver() - def test_try_embed_chunk(self): - assert AmazonBedrockTitanEmbeddingDriver().try_embed_chunk("foobar") == [0, 1, 0] + @pytest.mark.parametrize( + ("chunk", "expected_output", "expected_error"), + [ + ("foobar", [0, 1, 0], nullcontext()), + ( + b"foobar", + [], + pytest.raises(ValueError, match="AmazonBedrockTitanEmbeddingDriver does not support embedding bytes."), + ), + ], + ) + def test_try_embed_chunk(self, chunk, expected_output, expected_error): + with expected_error: + assert AmazonBedrockTitanEmbeddingDriver().try_embed_chunk(chunk) == expected_output diff --git a/tests/unit/drivers/embedding/test_azure_openai_embedding_driver.py b/tests/unit/drivers/embedding/test_azure_openai_embedding_driver.py index adea78130..f7cebfa22 100644 --- a/tests/unit/drivers/embedding/test_azure_openai_embedding_driver.py +++ b/tests/unit/drivers/embedding/test_azure_openai_embedding_driver.py @@ -1,3 +1,4 @@ +from contextlib import nullcontext from unittest.mock import Mock import pytest @@ -27,5 +28,17 @@ def test_init(self, driver): assert driver assert AzureOpenAiEmbeddingDriver(azure_endpoint="foobar", model="gpt-4").azure_deployment == "gpt-4" - def test_embed_chunk(self, driver): - assert driver.try_embed_chunk("foobar") == [0, 1, 0] + @pytest.mark.parametrize( + ("chunk", "expected_output", "expected_error"), + [ + ("foobar", [0, 1, 0], nullcontext()), + ( + b"foobar", + [], + pytest.raises(ValueError, match="AzureOpenAiEmbeddingDriver does not support embedding bytes."), + ), + ], + ) + def test_embed_chunk(self, driver, chunk, expected_output, expected_error): + with expected_error: + assert driver.try_embed_chunk(chunk) == expected_output diff --git a/tests/unit/drivers/embedding/test_cohere_embedding_driver.py b/tests/unit/drivers/embedding/test_cohere_embedding_driver.py index 6281e69b8..ca38dd904 100644 --- a/tests/unit/drivers/embedding/test_cohere_embedding_driver.py +++ b/tests/unit/drivers/embedding/test_cohere_embedding_driver.py @@ -1,3 +1,4 @@ +from contextlib import nullcontext from unittest.mock import Mock import pytest @@ -17,7 +18,22 @@ def mock_client(self, mocker): def test_init(self): assert CohereEmbeddingDriver(model="embed-english-v3.0", api_key="bar", input_type="search_document") - def test_try_embed_chunk(self): - assert CohereEmbeddingDriver( - model="embed-english-v3.0", api_key="bar", input_type="search_document" - ).try_embed_chunk("foobar") == [0, 1, 0] + @pytest.mark.parametrize( + ("chunk", "expected_output", "expected_error"), + [ + ("foobar", [0, 1, 0], nullcontext()), + ( + b"foobar", + [], + pytest.raises(ValueError, match="CohereEmbeddingDriver does not support embedding bytes."), + ), + ], + ) + def test_try_embed_chunk(self, chunk, expected_output, expected_error): + with expected_error: + assert ( + CohereEmbeddingDriver( + model="embed-english-v3.0", api_key="bar", input_type="search_document" + ).try_embed_chunk(chunk) + == expected_output + ) diff --git a/tests/unit/drivers/embedding/test_google_embedding_driver.py b/tests/unit/drivers/embedding/test_google_embedding_driver.py index 4300b5802..14ba889db 100644 --- a/tests/unit/drivers/embedding/test_google_embedding_driver.py +++ b/tests/unit/drivers/embedding/test_google_embedding_driver.py @@ -1,3 +1,4 @@ +from contextlib import nullcontext from unittest.mock import MagicMock import pytest @@ -20,5 +21,17 @@ def mock_genai(self, mocker): def test_init(self): assert GoogleEmbeddingDriver() - def test_try_embed_chunk(self): - assert GoogleEmbeddingDriver().try_embed_chunk("foobar") == [0, 1, 0] + @pytest.mark.parametrize( + ("chunk", "expected_output", "expected_error"), + [ + ("foobar", [0, 1, 0], nullcontext()), + ( + b"foobar", + [], + pytest.raises(ValueError, match="GoogleEmbeddingDriver does not support embedding bytes."), + ), + ], + ) + def test_try_embed_chunk(self, chunk, expected_output, expected_error): + with expected_error: + assert GoogleEmbeddingDriver().try_embed_chunk(chunk) == expected_output diff --git a/tests/unit/drivers/embedding/test_hugging_face_hub_embedding_driver.py b/tests/unit/drivers/embedding/test_hugging_face_hub_embedding_driver.py new file mode 100644 index 000000000..8026f9db8 --- /dev/null +++ b/tests/unit/drivers/embedding/test_hugging_face_hub_embedding_driver.py @@ -0,0 +1,38 @@ +from contextlib import nullcontext +from unittest.mock import Mock + +import pytest + +from griptape.drivers.embedding.huggingface_hub_embedding_driver import HuggingFaceHubEmbeddingDriver + + +class TestHuggingFaceHubEmbeddingDriver: + @pytest.fixture(autouse=True) + def mock_client(self, mocker): + mock_client = mocker.patch("huggingface_hub.InferenceClient").return_value + + mock_response = Mock() + mock_response.flatten().tolist.return_value = [0, 1, 0] + mock_client.feature_extraction.return_value = mock_response + return mock_client + + def test_init(self): + assert HuggingFaceHubEmbeddingDriver(model="embed-english-v3.0", api_token="foo") + + @pytest.mark.parametrize( + ("chunk", "expected_output", "expected_error"), + [ + ("foobar", [0, 1, 0], nullcontext()), + ( + b"foobar", + [], + pytest.raises(ValueError, match="HuggingFaceHubEmbeddingDriver does not support embedding bytes."), + ), + ], + ) + def test_try_embed_chunk(self, chunk, expected_output, expected_error): + with expected_error: + assert ( + HuggingFaceHubEmbeddingDriver(model="embed-english-v3.0", api_token="bar").try_embed_chunk(chunk) + == expected_output + ) diff --git a/tests/unit/drivers/embedding/test_ollama_embedding_driver.py b/tests/unit/drivers/embedding/test_ollama_embedding_driver.py index 4a36d6aa6..cabf999be 100644 --- a/tests/unit/drivers/embedding/test_ollama_embedding_driver.py +++ b/tests/unit/drivers/embedding/test_ollama_embedding_driver.py @@ -1,3 +1,5 @@ +from contextlib import nullcontext + import pytest from griptape.drivers.embedding.ollama import OllamaEmbeddingDriver @@ -15,5 +17,17 @@ def mock_client(self, mocker): def test_init(self): assert OllamaEmbeddingDriver(model="foo") - def test_try_embed_chunk(self): - assert OllamaEmbeddingDriver(model="foo").try_embed_chunk("foobar") == [0, 1, 0] + @pytest.mark.parametrize( + ("chunk", "expected_output", "expected_error"), + [ + ("foobar", [0, 1, 0], nullcontext()), + ( + b"foobar", + [], + pytest.raises(ValueError, match="OllamaEmbeddingDriver does not support embedding bytes."), + ), + ], + ) + def test_try_embed_chunk(self, chunk, expected_output, expected_error): + with expected_error: + assert OllamaEmbeddingDriver(model="foo").try_embed_chunk(chunk) == expected_output diff --git a/tests/unit/drivers/embedding/test_openai_embedding_driver.py b/tests/unit/drivers/embedding/test_openai_embedding_driver.py index 4d5215bf1..1bf7b5547 100644 --- a/tests/unit/drivers/embedding/test_openai_embedding_driver.py +++ b/tests/unit/drivers/embedding/test_openai_embedding_driver.py @@ -1,3 +1,4 @@ +from contextlib import nullcontext from unittest.mock import Mock import pytest @@ -23,8 +24,16 @@ def mock_openai(self, mocker): def test_init(self): assert OpenAiEmbeddingDriver() - def test_try_embed_chunk(self): - assert OpenAiEmbeddingDriver().try_embed_chunk("foobar") == [0, 1, 0] + @pytest.mark.parametrize( + ("chunk", "expected_output", "expected_error"), + [ + ("foobar", [0, 1, 0], nullcontext()), + (b"foobar", [], pytest.raises(ValueError, match="OpenAiEmbeddingDriver does not support embedding bytes.")), + ], + ) + def test_try_embed_chunk(self, chunk, expected_output, expected_error): + with expected_error: + assert OpenAiEmbeddingDriver().try_embed_chunk(chunk) == expected_output @pytest.mark.parametrize("model", OpenAiTokenizer.EMBEDDING_MODELS) def test_try_embed_chunk_replaces_newlines_in_older_ada_models(self, model, mock_openai): diff --git a/tests/unit/drivers/embedding/test_sagemaker_jumpstart_embedding_driver.py b/tests/unit/drivers/embedding/test_sagemaker_jumpstart_embedding_driver.py index 525730e88..3e44d2fd2 100644 --- a/tests/unit/drivers/embedding/test_sagemaker_jumpstart_embedding_driver.py +++ b/tests/unit/drivers/embedding/test_sagemaker_jumpstart_embedding_driver.py @@ -57,3 +57,15 @@ def test_try_embed_chunk(self, mock_client): model="test-model", tokenizer=OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL), ).try_embed_chunk("foobar") == [0, 2, 0] + + with pytest.raises( + ValueError, match="AmazonSageMakerJumpstartEmbeddingDriver does not support embedding bytes." + ): + assert ( + AmazonSageMakerJumpstartEmbeddingDriver( + endpoint="test-endpoint", + model="test-model", + tokenizer=OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL), + ).try_embed_chunk(b"foobar") + == [] + ) diff --git a/tests/unit/drivers/embedding/test_voyageai_embedding_driver.py b/tests/unit/drivers/embedding/test_voyageai_embedding_driver.py index 064027938..54de57a55 100644 --- a/tests/unit/drivers/embedding/test_voyageai_embedding_driver.py +++ b/tests/unit/drivers/embedding/test_voyageai_embedding_driver.py @@ -1,3 +1,4 @@ +from contextlib import nullcontext from unittest.mock import Mock import pytest @@ -16,5 +17,17 @@ def mock_client(self, mocker): def test_init(self): assert VoyageAiEmbeddingDriver() - def test_try_embed_chunk(self): - assert VoyageAiEmbeddingDriver().try_embed_chunk("foobar") == [0, 1, 0] + @pytest.mark.parametrize( + ("chunk", "expected_output", "expected_error"), + [ + ("foobar", [0, 1, 0], nullcontext()), + ( + b"foobar", + [], + pytest.raises(ValueError, match="VoyageAiEmbeddingDriver does not support embedding bytes."), + ), + ], + ) + def test_try_embed_chunk(self, chunk, expected_output, expected_error): + with expected_error: + assert VoyageAiEmbeddingDriver().try_embed_chunk(chunk) == expected_output diff --git a/tests/unit/drivers/vector/test_griptape_cloud_vector_store_driver.py b/tests/unit/drivers/vector/test_griptape_cloud_vector_store_driver.py index e5b41fe83..667800f96 100644 --- a/tests/unit/drivers/vector/test_griptape_cloud_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_griptape_cloud_vector_store_driver.py @@ -2,6 +2,7 @@ import pytest +from griptape.artifacts.text_artifact import TextArtifact from griptape.drivers.vector.griptape_cloud import GriptapeCloudVectorStoreDriver @@ -73,3 +74,9 @@ def test_query_defaults(self, driver): assert result[1].meta == self.test_metas[1] assert result[0].score == self.test_scores[0] assert result[1].score == self.test_scores[1] + + def test_query_artifact(self, driver): + with pytest.raises( + ValueError, match="GriptapeCloudVectorStoreDriver does not support querying with Artifacts." + ): + driver.query(TextArtifact("some query")) diff --git a/tests/unit/drivers/vector/test_marqo_vector_store_driver.py b/tests/unit/drivers/vector/test_marqo_vector_store_driver.py index 1bf433aa2..91be8094d 100644 --- a/tests/unit/drivers/vector/test_marqo_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_marqo_vector_store_driver.py @@ -117,9 +117,14 @@ def test_upsert_text_artifact(self, driver, mock_marqo): } assert result == expected_return_value["items"][0]["_id"] - def test_query_vector(self, driver): - with pytest.raises(NotImplementedError): - driver.query_vector([0.0, 0.5]) + def test_query_vector(self, driver, mock_marqo): + results = driver.query_vector([0.1, 0.2, 0.3]) + mock_marqo.index().search.assert_called() + assert len(results) == 1 + assert results[0].score == 0.6047464 + assert results[0].meta["Title"] == "Test Title" + assert results[0].meta["Description"] == "Test description" + assert results[0].id == "5aed93eb-3878-4f12-bc92-0fda01c7d23d" def test_search(self, driver, mock_marqo): results = driver.query("Test query")