feat(drivers-vector): support upserting/querying ImageArtifacts
feat(drivers-embedding): support embedding `ImageArtifact`s
collindutter committed Feb 20, 2025
1 parent 7155088 commit 79e7dcf
Showing 20 changed files with 170 additions and 167 deletions.
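In short, this change lets `BaseEmbeddingDriver.embed_artifact()` handle `TextArtifact`s, `ImageArtifact`s, and `ListArtifact`s, and gives `BaseVectorStoreDriver` a generic `upsert()`/`query()` that accept artifacts as well as plain strings. The sketch below shows the intended usage; it is not part of the diff, the import paths and `ImageArtifact` constructor arguments are assumptions about the surrounding library, and the toy driver stands in for a real multimodal embedding model (the built-in drivers shown below reject image bytes).

```python
from __future__ import annotations

from attrs import define, field

from griptape.artifacts import ImageArtifact, TextArtifact
from griptape.drivers.embedding import BaseEmbeddingDriver  # import paths assumed
from griptape.drivers.vector import LocalVectorStoreDriver


@define
class ToyMultimodalEmbeddingDriver(BaseEmbeddingDriver):
    """Stand-in for a real multimodal embedding model."""

    model: str = field(default="toy", kw_only=True)

    def try_embed_chunk(self, chunk: str | bytes) -> list[float]:
        data = chunk.encode() if isinstance(chunk, str) else chunk
        # Deterministic toy vector; a real driver would call its model API here.
        return [float(sum(data) % 97) + 1.0, float(len(data) % 89) + 1.0]


vector_store = LocalVectorStoreDriver(embedding_driver=ToyMultimodalEmbeddingDriver())
image = ImageArtifact(b"\x89PNG...", format="png", width=1, height=1)  # constructor args assumed

vector_store.upsert("plain strings are wrapped in a TextArtifact", namespace="docs")
vector_store.upsert(TextArtifact("a photo of a red bicycle"), namespace="docs")
vector_store.upsert(image, namespace="images")

print(vector_store.query("red bicycle", count=2))  # query by string, as before
print(vector_store.query(image, count=1))          # query by ImageArtifact, new in this change
```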
griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py
@@ -46,7 +46,9 @@ class AmazonBedrockCohereEmbeddingDriver(BaseEmbeddingDriver):
def client(self) -> BedrockClient:
return self.session.client("bedrock-runtime")

def try_embed_chunk(self, chunk: str) -> list[float]:
def try_embed_chunk(self, chunk: str | bytes) -> list[float]:
if isinstance(chunk, bytes):
raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.")

Codecov warning: added line #L51 was not covered by tests.
payload = {"input_type": self.input_type, "texts": [chunk]}

response = self.client.invoke_model(
griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
@@ -42,7 +42,9 @@ class AmazonBedrockTitanEmbeddingDriver(BaseEmbeddingDriver):
def client(self) -> BedrockClient:
return self.session.client("bedrock-runtime")

def try_embed_chunk(self, chunk: str) -> list[float]:
def try_embed_chunk(self, chunk: str | bytes) -> list[float]:
if isinstance(chunk, bytes):
raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.")

Codecov warning: added line #L47 was not covered by tests.
payload = {"inputText": chunk}

response = self.client.invoke_model(
griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py
@@ -26,7 +26,9 @@ class AmazonSageMakerJumpstartEmbeddingDriver(BaseEmbeddingDriver):
def client(self) -> SageMakerClient:
return self.session.client("sagemaker-runtime")

def try_embed_chunk(self, chunk: str) -> list[float]:
def try_embed_chunk(self, chunk: str | bytes) -> list[float]:
if isinstance(chunk, bytes):
raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.")

Codecov warning: added line #L31 was not covered by tests.
payload = {"text_inputs": chunk, "mode": "embedding"}

endpoint_response = self.client.invoke_endpoint(
19 changes: 17 additions & 2 deletions griptape/drivers/embedding/base_embedding_driver.py
@@ -6,12 +6,12 @@
import numpy as np
from attrs import define, field

from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact
from griptape.chunkers import BaseChunker, TextChunker
from griptape.mixins.exponential_backoff_mixin import ExponentialBackoffMixin
from griptape.mixins.serializable_mixin import SerializableMixin

if TYPE_CHECKING:
from griptape.artifacts import TextArtifact
from griptape.tokenizers import BaseTokenizer


@@ -31,6 +31,21 @@ class BaseEmbeddingDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
def __attrs_post_init__(self) -> None:
self.chunker = TextChunker(tokenizer=self.tokenizer) if self.tokenizer else None

def embed_artifact(
self, artifact: TextArtifact | ImageArtifact | ListArtifact[TextArtifact | ImageArtifact]
) -> list[float]:
if isinstance(artifact, TextArtifact):
return self.embed_text_artifact(artifact)
elif isinstance(artifact, ImageArtifact):
return self.embed_image_artifact(artifact)

Codecov warning: added line #L40 was not covered by tests.
else:
embeddings = [self.embed_artifact(artifact) for artifact in artifact]

Codecov warning: added line #L42 was not covered by tests.

return np.average(embeddings, axis=0).tolist()

Codecov warning: added line #L44 was not covered by tests.

def embed_image_artifact(self, artifact: ImageArtifact) -> list[float]:
return self.try_embed_chunk(artifact.value)

Codecov warning: added line #L47 was not covered by tests.

def embed_text_artifact(self, artifact: TextArtifact) -> list[float]:
return self.embed_string(artifact.to_text())

@@ -46,7 +61,7 @@ def embed_string(self, string: str) -> list[float]:
raise RuntimeError("Failed to embed string.")

@abstractmethod
def try_embed_chunk(self, chunk: str) -> list[float]: ...
def try_embed_chunk(self, chunk: str | bytes) -> list[float]: ...

def _embed_long_string(self, string: str) -> list[float]:
"""Embeds a string that is too long to embed in one go.
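For reference, `embed_artifact()` dispatches on artifact type: a `TextArtifact` goes through the existing `embed_string()` path, an `ImageArtifact` is handed to `try_embed_chunk()` as raw bytes, and a `ListArtifact` is embedded element-wise and averaged with `np.average(..., axis=0)`. A toy subclass (not part of the diff; the import path, field defaults, and `ImageArtifact` arguments are assumptions) makes the dispatch visible:

```python
from __future__ import annotations

from attrs import define, field

from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact
from griptape.drivers.embedding import BaseEmbeddingDriver  # import path assumed


@define
class DispatchProbeEmbeddingDriver(BaseEmbeddingDriver):
    model: str = field(default="probe", kw_only=True)

    def try_embed_chunk(self, chunk: str | bytes) -> list[float]:
        # Text chunks map to [1, 0] and image bytes to [0, 1]; a real driver calls a model API.
        return [0.0, 1.0] if isinstance(chunk, bytes) else [1.0, 0.0]


driver = DispatchProbeEmbeddingDriver()
text = TextArtifact("hello")
image = ImageArtifact(b"\x89PNG...", format="png", width=1, height=1)  # constructor args assumed

print(driver.embed_artifact(text))                         # [1.0, 0.0]
print(driver.embed_artifact(image))                        # [0.0, 1.0]
print(driver.embed_artifact(ListArtifact([text, image])))  # element-wise average: [0.5, 0.5]
```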
4 changes: 3 additions & 1 deletion griptape/drivers/embedding/cohere_embedding_driver.py
@@ -39,7 +39,9 @@ class CohereEmbeddingDriver(BaseEmbeddingDriver):
def client(self) -> Client:
return import_optional_dependency("cohere").Client(self.api_key)

def try_embed_chunk(self, chunk: str) -> list[float]:
def try_embed_chunk(self, chunk: str | bytes) -> list[float]:
if isinstance(chunk, bytes):
raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.")

Codecov warning: added line #L44 was not covered by tests.
result = self.client.embed(texts=[chunk], model=self.model, input_type=self.input_type)

if isinstance(result.embeddings, list):
2 changes: 1 addition & 1 deletion griptape/drivers/embedding/dummy_embedding_driver.py
@@ -10,5 +10,5 @@
class DummyEmbeddingDriver(BaseEmbeddingDriver):
model: None = field(init=False, default=None, kw_only=True)

def try_embed_chunk(self, chunk: str) -> list[float]:
def try_embed_chunk(self, chunk: str | bytes) -> list[float]:
raise DummyError(__class__.__name__, "try_embed_chunk")
4 changes: 3 additions & 1 deletion griptape/drivers/embedding/google_embedding_driver.py
@@ -26,7 +26,9 @@ class GoogleEmbeddingDriver(BaseEmbeddingDriver):
task_type: str = field(default="retrieval_document", kw_only=True, metadata={"serializable": True})
title: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})

def try_embed_chunk(self, chunk: str) -> list[float]:
def try_embed_chunk(self, chunk: str | bytes) -> list[float]:
if isinstance(chunk, bytes):
raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.")

Codecov warning: added line #L31 was not covered by tests.
genai = import_optional_dependency("google.generativeai")
genai.configure(api_key=self.api_key)

griptape/drivers/embedding/huggingface_hub_embedding_driver.py
@@ -32,7 +32,9 @@ def client(self) -> InferenceClient:
token=self.api_token,
)

def try_embed_chunk(self, chunk: str) -> list[float]:
def try_embed_chunk(self, chunk: str | bytes) -> list[float]:
if isinstance(chunk, bytes):
raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.")

Codecov warning: added line #L37 was not covered by tests.
response = self.client.feature_extraction(chunk)

return response.flatten().tolist()
4 changes: 3 additions & 1 deletion griptape/drivers/embedding/ollama_embedding_driver.py
@@ -30,5 +30,7 @@ class OllamaEmbeddingDriver(BaseEmbeddingDriver):
def client(self) -> Client:
return import_optional_dependency("ollama").Client(host=self.host)

def try_embed_chunk(self, chunk: str) -> list[float]:
def try_embed_chunk(self, chunk: str | bytes) -> list[float]:
if isinstance(chunk, bytes):
raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.")

Codecov warning: added line #L35 was not covered by tests.
return list(self.client.embeddings(model=self.model, prompt=chunk)["embedding"])
4 changes: 3 additions & 1 deletion griptape/drivers/embedding/openai_embedding_driver.py
@@ -44,7 +44,9 @@ class OpenAiEmbeddingDriver(BaseEmbeddingDriver):
def client(self) -> openai.OpenAI:
return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization)

def try_embed_chunk(self, chunk: str) -> list[float]:
def try_embed_chunk(self, chunk: str | bytes) -> list[float]:
if isinstance(chunk, bytes):
raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.")

Codecov warning: added line #L49 was not covered by tests.
# Address a performance issue in older ada models
# https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
if self.model.endswith("001"):
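Every text-only driver in this change set gets the same guard: `try_embed_chunk()` now accepts `str | bytes` but raises as soon as it sees bytes, so embedding an `ImageArtifact` fails fast instead of sending binary data to a text endpoint. A sketch of the resulting behavior, assuming the import path shown; the guard trips before any API request, so no credentials are needed:

```python
from griptape.artifacts import ImageArtifact
from griptape.drivers.embedding.openai import OpenAiEmbeddingDriver  # import path assumed

driver = OpenAiEmbeddingDriver()  # text-only: image bytes are rejected before any request is made
image = ImageArtifact(b"\x89PNG...", format="png", width=1, height=1)  # constructor args assumed

try:
    driver.embed_artifact(image)
except ValueError as error:
    print(error)  # OpenAiEmbeddingDriver does not support embedding bytes.
```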
4 changes: 3 additions & 1 deletion griptape/drivers/embedding/voyageai_embedding_driver.py
@@ -40,5 +40,7 @@ class VoyageAiEmbeddingDriver(BaseEmbeddingDriver):
def client(self) -> Any:
return import_optional_dependency("voyageai").Client(api_key=self.api_key)

def try_embed_chunk(self, chunk: str) -> list[float]:
def try_embed_chunk(self, chunk: str | bytes) -> list[float]:
if isinstance(chunk, bytes):
raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.")

Codecov warning: added line #L45 was not covered by tests.
return self.client.embed([chunk], model=self.model, input_type=self.input_type).embeddings[0]
20 changes: 0 additions & 20 deletions griptape/drivers/vector/azure_mongodb_vector_store_driver.py
@@ -61,23 +61,3 @@ def query_vector(
)
for doc in collection.aggregate(pipeline)
]

def query(
self,
query: str,
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
offset: Optional[int] = None,
**kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
"""Queries the MongoDB collection for documents that match the provided query string.
Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index.
"""
# Using the embedding driver to convert the query string into a vector
vector = self.embedding_driver.embed_string(query)
return self.query_vector(
vector, count=count, namespace=namespace, include_vectors=include_vectors, offset=offset, **kwargs
)
91 changes: 66 additions & 25 deletions griptape/drivers/vector/base_vector_store_driver.py
@@ -2,13 +2,13 @@

import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Optional, overload

from attrs import define, field

from griptape import utils
from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact
from griptape.artifacts.image_artifact import ImageArtifact
from griptape.mixins.futures_executor_mixin import FuturesExecutorMixin
from griptape.mixins.serializable_mixin import SerializableMixin
from griptape.utils import with_contextvars
@@ -36,18 +36,65 @@ def to_artifact(self) -> BaseArtifact:

def upsert_text_artifacts(
self,
artifacts: list[TextArtifact] | dict[str, list[TextArtifact]],
artifacts: list[TextArtifact] | dict[str, list[TextArtifact] | list[ImageArtifact]],
*,
meta: Optional[dict] = None,
**kwargs,
) -> list[str] | dict[str, list[str]]:
return self.upsert_collection(artifacts, meta=meta, **kwargs)

def upsert_text_artifact(
self,
artifact: TextArtifact,
*,
namespace: Optional[str] = None,
meta: Optional[dict] = None,
vector_id: Optional[str] = None,
**kwargs,
) -> str:
return self.upsert(artifact, namespace=namespace, meta=meta, vector_id=vector_id, **kwargs)

def upsert_text(
self,
string: str,
*,
namespace: Optional[str] = None,
meta: Optional[dict] = None,
vector_id: Optional[str] = None,
**kwargs,
) -> str:
return self.upsert(string, namespace=namespace, meta=meta, vector_id=vector_id, **kwargs)

@overload
def upsert_collection(
self,
artifacts: list[TextArtifact],
*,
meta: Optional[dict] = None,
**kwargs,
) -> list[str]: ...

@overload
def upsert_collection(
self,
artifacts: dict[str, list[TextArtifact] | list[ImageArtifact]],
*,
meta: Optional[dict] = None,
**kwargs,
) -> dict[str, list[str]]: ...

def upsert_collection(
self,
artifacts: list[TextArtifact] | dict[str, list[TextArtifact] | list[ImageArtifact]],
*,
meta: Optional[dict] = None,
**kwargs,
):
with self.create_futures_executor() as futures_executor:
if isinstance(artifacts, list):
return utils.execute_futures_list(
[
futures_executor.submit(
with_contextvars(self.upsert_text_artifact), a, namespace=None, meta=meta, **kwargs
)
futures_executor.submit(with_contextvars(self.upsert), a, namespace=None, meta=meta, **kwargs)
for a in artifacts
],
)
@@ -61,21 +61,23 @@ def upsert_text_artifacts(

futures_dict[namespace].append(
futures_executor.submit(
with_contextvars(self.upsert_text_artifact), a, namespace=namespace, meta=meta, **kwargs
with_contextvars(self.upsert), a, namespace=namespace, meta=meta, **kwargs
)
)

return utils.execute_futures_list_dict(futures_dict)

def upsert_text_artifact(
def upsert(
self,
artifact: TextArtifact,
value: str | TextArtifact | ImageArtifact,
*,
namespace: Optional[str] = None,
meta: Optional[dict] = None,
vector_id: Optional[str] = None,
**kwargs,
) -> str:
artifact = TextArtifact(value) if isinstance(value, str) else value

meta = {} if meta is None else meta

if vector_id is None:
@@ -87,23 +136,10 @@ def upsert_text_artifact(
else:
meta["artifact"] = artifact.to_json()

vector = artifact.embedding or artifact.generate_embedding(self.embedding_driver)
vector = self.embedding_driver.embed_artifact(artifact)

return self.upsert_vector(vector, vector_id=vector_id, namespace=namespace, meta=meta, **kwargs)

def upsert_text(
self,
string: str,
*,
namespace: Optional[str] = None,
meta: Optional[dict] = None,
vector_id: Optional[str] = None,
**kwargs,
) -> str:
return self.upsert_text_artifact(
TextArtifact(string), vector_id=vector_id, namespace=namespace, meta=meta, **kwargs
)

def does_entry_exist(self, vector_id: str, *, namespace: Optional[str] = None) -> bool:
try:
return self.load_entry(vector_id, namespace=namespace) is not None
@@ -150,14 +186,19 @@ def query_vector(

def query(
self,
query: str,
query: str | BaseArtifact,
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
**kwargs,
) -> list[Entry]:
vector = self.embedding_driver.embed_string(query)
if isinstance(query, str):
vector = self.embedding_driver.embed_string(query)
elif isinstance(query, (TextArtifact, ImageArtifact, ListArtifact)):
vector = self.embedding_driver.embed_artifact(query)

Codecov warning: added line #L199 was not covered by tests.
else:
raise ValueError(f"Unsupported query type: {type(query)}")

Codecov warning: added line #L201 was not covered by tests.
return self.query_vector(vector, count=count, namespace=namespace, include_vectors=include_vectors, **kwargs)

def _get_default_vector_id(self, value: str) -> str:
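The net effect on the base vector store driver: `upsert()` becomes the single entry point (plain strings are wrapped in a `TextArtifact`, and embeddings now come from `embedding_driver.embed_artifact()`), `upsert_text()` and `upsert_text_artifact()` turn into thin wrappers around it, and `upsert_collection()` is overloaded so that a flat list returns a list of vector ids while a namespace-keyed dict returns a dict of id lists. A minimal sketch, assuming the import paths and using a toy embedding driver in place of a real model:

```python
from __future__ import annotations

from attrs import define, field

from griptape.artifacts import TextArtifact
from griptape.drivers.embedding import BaseEmbeddingDriver  # import paths assumed
from griptape.drivers.vector import LocalVectorStoreDriver


@define
class TinyEmbeddingDriver(BaseEmbeddingDriver):
    model: str = field(default="tiny", kw_only=True)

    def try_embed_chunk(self, chunk: str | bytes) -> list[float]:
        data = chunk.encode() if isinstance(chunk, str) else chunk
        return [float(len(data)), 1.0]  # toy vector; a real driver calls a model API


vector_store = LocalVectorStoreDriver(embedding_driver=TinyEmbeddingDriver())

# A flat list returns a list of vector ids.
ids = vector_store.upsert_collection([TextArtifact("alpha"), TextArtifact("beta")])

# A namespace-keyed dict returns a dict of id lists, one entry per namespace.
ids_by_namespace = vector_store.upsert_collection(
    {"docs": [TextArtifact("gamma")], "notes": [TextArtifact("delta")]}
)

print(ids)               # ['<id of alpha>', '<id of beta>']
print(ids_by_namespace)  # {'docs': ['<id of gamma>'], 'notes': ['<id of delta>']}
```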
3 changes: 2 additions & 1 deletion griptape/drivers/vector/dummy_vector_store_driver.py
@@ -9,6 +9,7 @@
from griptape.exceptions import DummyError

if TYPE_CHECKING:
from griptape.artifacts import BaseArtifact
from griptape.drivers.embedding import BaseEmbeddingDriver


@@ -52,7 +53,7 @@ def query_vector(

def query(
self,
query: str,
query: str | BaseArtifact,
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
griptape/drivers/vector/griptape_cloud_vector_store_driver.py
@@ -7,6 +7,7 @@
import requests
from attrs import Factory, define, field

from griptape.artifacts import BaseArtifact
from griptape.drivers.embedding.dummy import DummyEmbeddingDriver
from griptape.drivers.vector import BaseVectorStoreDriver

@@ -83,7 +84,7 @@ def load_artifacts(self, *, namespace: Optional[str] = None) -> ListArtifact:

def query(
self,
query: str,
query: str | BaseArtifact,
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
@@ -97,6 +98,8 @@ def query(
Performs a query on the Knowledge Base and returns Artifacts with close vector proximity to the query, optionally filtering to only those that match the provided filter(s).
"""
if isinstance(query, BaseArtifact):
raise ValueError(f"{self.__class__.__name__} does not support querying with Artifacts.")

Codecov warning: added line #L102 was not covered by tests.
url = urljoin(self.base_url.strip("/"), f"/api/knowledge-bases/{self.knowledge_base_id}/query")

query_args = {
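The Griptape Cloud driver sends plain-text queries to a hosted knowledge base, so it rejects artifact queries up front rather than embedding them locally. A sketch of that behavior; the import path and every constructor field except `knowledge_base_id` are assumptions, and the guard trips before any HTTP request is made:

```python
from griptape.artifacts import ImageArtifact
from griptape.drivers.vector.griptape_cloud import GriptapeCloudVectorStoreDriver  # import path assumed

# knowledge_base_id appears in the diff above; the other constructor fields are assumptions.
driver = GriptapeCloudVectorStoreDriver(api_key="...", knowledge_base_id="...")
image = ImageArtifact(b"\x89PNG...", format="png", width=1, height=1)

try:
    driver.query(image)
except ValueError as error:
    print(error)  # GriptapeCloudVectorStoreDriver does not support querying with Artifacts.
```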