Commit

feat(drivers-vector-local): implement more robust persist_file saving that supports multi-modal artifacts

feat(drivers-vector): support upserting/querying `ImageArtifact`s
feat(drivers-embedding): support embedding `ImageArtifact`s
collindutter committed Feb 20, 2025
1 parent 762958f commit d838100
Showing 34 changed files with 385 additions and 212 deletions.
4 changes: 3 additions & 1 deletion griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py
@@ -46,7 +46,9 @@ class AmazonBedrockCohereEmbeddingDriver(BaseEmbeddingDriver):
     def client(self) -> BedrockClient:
         return self.session.client("bedrock-runtime")

-    def try_embed_chunk(self, chunk: str) -> list[float]:
+    def try_embed_chunk(self, chunk: str | bytes) -> list[float]:
+        if isinstance(chunk, bytes):
+            raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.")
         payload = {"input_type": self.input_type, "texts": [chunk]}

         response = self.client.invoke_model(
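
Every text-only embedding driver in this commit gets the same treatment: `try_embed_chunk` widens its signature to `str | bytes` to satisfy the new base-class contract, and a runtime guard rejects binary input. A minimal, self-contained sketch of that contract (the class below is a stand-in for illustration, not the real Bedrock driver):

# Sketch only: a stand-in driver demonstrating the widened str | bytes contract.
class TextOnlyDriver:
    def try_embed_chunk(self, chunk: str | bytes) -> list[float]:
        if isinstance(chunk, bytes):
            raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.")
        return [float(len(chunk))]  # placeholder embedding

driver = TextOnlyDriver()
driver.try_embed_chunk("hello")  # returns a vector
try:
    driver.try_embed_chunk(b"\x89PNG")  # raw image bytes are rejected
except ValueError as err:
    print(err)  # TextOnlyDriver does not support embedding bytes.
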
4 changes: 3 additions & 1 deletion griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
@@ -42,7 +42,9 @@ class AmazonBedrockTitanEmbeddingDriver(BaseEmbeddingDriver):
     def client(self) -> BedrockClient:
         return self.session.client("bedrock-runtime")

-    def try_embed_chunk(self, chunk: str) -> list[float]:
+    def try_embed_chunk(self, chunk: str | bytes) -> list[float]:
+        if isinstance(chunk, bytes):
+            raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.")
         payload = {"inputText": chunk}

         response = self.client.invoke_model(
4 changes: 3 additions & 1 deletion griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py
@@ -26,7 +26,9 @@ class AmazonSageMakerJumpstartEmbeddingDriver(BaseEmbeddingDriver):
     def client(self) -> SageMakerClient:
         return self.session.client("sagemaker-runtime")

-    def try_embed_chunk(self, chunk: str) -> list[float]:
+    def try_embed_chunk(self, chunk: str | bytes) -> list[float]:
+        if isinstance(chunk, bytes):
+            raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.")
         payload = {"text_inputs": chunk, "mode": "embedding"}

         endpoint_response = self.client.invoke_endpoint(
30 changes: 22 additions & 8 deletions griptape/drivers/embedding/base_embedding_driver.py
@@ -6,12 +6,12 @@
 import numpy as np
 from attrs import define, field

+from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact
 from griptape.chunkers import BaseChunker, TextChunker
 from griptape.mixins.exponential_backoff_mixin import ExponentialBackoffMixin
 from griptape.mixins.serializable_mixin import SerializableMixin

 if TYPE_CHECKING:
-    from griptape.artifacts import TextArtifact
     from griptape.tokenizers import BaseTokenizer

@@ -35,18 +35,32 @@ def embed_text_artifact(self, artifact: TextArtifact) -> list[float]:
         return self.embed_string(artifact.to_text())

     def embed_string(self, string: str) -> list[float]:
+        if self.tokenizer is not None and self.tokenizer.count_tokens(string) > self.tokenizer.max_input_tokens:
+            return self._embed_long_string(string)
+        else:
+            return self.try_embed_chunk(string)
+
+    def embed(
+        self, value: str | bytes | TextArtifact | ImageArtifact | ListArtifact[TextArtifact | ImageArtifact]
+    ) -> list[float]:
         for attempt in self.retrying():
             with attempt:
-                if self.tokenizer is not None and self.tokenizer.count_tokens(string) > self.tokenizer.max_input_tokens:
-                    return self._embed_long_string(string)
-                else:
-                    return self.try_embed_chunk(string)
+                if isinstance(value, str):
+                    return self.embed_string(value)
+                elif isinstance(value, bytes):
+                    return self.try_embed_chunk(value)
+                elif isinstance(value, (TextArtifact, ImageArtifact)):
+                    return self.embed(value.value)
+                elif isinstance(value, ListArtifact):
+                    # If multiple artifacts are provided, embed each one and average the results
+                    embeddings = [self.embed(artifact) for artifact in value]
+
+                    return np.average(embeddings, axis=0).tolist()
         else:
-            raise RuntimeError("Failed to embed string.")
+            raise ValueError("Failed to embed value.")

     @abstractmethod
-    def try_embed_chunk(self, chunk: str) -> list[float]: ...
+    def try_embed_chunk(self, chunk: str | bytes) -> list[float]: ...

     def _embed_long_string(self, string: str) -> list[float]:
         """Embeds a string that is too long to embed in one go.
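
The new `embed` method is now the single entry point: plain strings go through `embed_string` (which chunks over-long input), raw bytes go straight to `try_embed_chunk`, artifacts are unwrapped to their `value`, and a `ListArtifact` is embedded element-wise and averaged with `np.average(..., axis=0)`. A toy sketch of that dispatch, using plain lists in place of `ListArtifact` and a fake length-based embedding so the averaging is easy to verify:

import numpy as np

# Toy driver: the "embedding" is [length, source-flag], just to make the
# averaging visible. str/bytes/list stand in for the artifact types.
class ToyEmbeddingDriver:
    def embed_string(self, string: str) -> list[float]:
        return [float(len(string)), 0.0]

    def try_embed_chunk(self, chunk: str | bytes) -> list[float]:
        return [float(len(chunk)), 1.0]

    def embed(self, value) -> list[float]:
        if isinstance(value, str):
            return self.embed_string(value)
        if isinstance(value, bytes):
            return self.try_embed_chunk(value)
        if isinstance(value, list):  # stands in for ListArtifact
            # Embed each element, then average the vectors component-wise.
            return np.average([self.embed(v) for v in value], axis=0).tolist()
        raise ValueError("Failed to embed value.")

driver = ToyEmbeddingDriver()
print(driver.embed(["ab", b"abcd"]))  # [3.0, 0.5]: mean of [2, 0] and [4, 1]
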
4 changes: 3 additions & 1 deletion griptape/drivers/embedding/cohere_embedding_driver.py
@@ -39,7 +39,9 @@ class CohereEmbeddingDriver(BaseEmbeddingDriver):
     def client(self) -> Client:
         return import_optional_dependency("cohere").Client(self.api_key)

-    def try_embed_chunk(self, chunk: str) -> list[float]:
+    def try_embed_chunk(self, chunk: str | bytes) -> list[float]:
+        if isinstance(chunk, bytes):
+            raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.")
         result = self.client.embed(texts=[chunk], model=self.model, input_type=self.input_type)

         if isinstance(result.embeddings, list):
2 changes: 1 addition & 1 deletion griptape/drivers/embedding/dummy_embedding_driver.py
@@ -10,5 +10,5 @@
 class DummyEmbeddingDriver(BaseEmbeddingDriver):
     model: None = field(init=False, default=None, kw_only=True)

-    def try_embed_chunk(self, chunk: str) -> list[float]:
+    def try_embed_chunk(self, chunk: str | bytes) -> list[float]:
         raise DummyError(__class__.__name__, "try_embed_chunk")
4 changes: 3 additions & 1 deletion griptape/drivers/embedding/google_embedding_driver.py
@@ -26,7 +26,9 @@ class GoogleEmbeddingDriver(BaseEmbeddingDriver):
     task_type: str = field(default="retrieval_document", kw_only=True, metadata={"serializable": True})
     title: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})

-    def try_embed_chunk(self, chunk: str) -> list[float]:
+    def try_embed_chunk(self, chunk: str | bytes) -> list[float]:
+        if isinstance(chunk, bytes):
+            raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.")
         genai = import_optional_dependency("google.generativeai")
         genai.configure(api_key=self.api_key)
4 changes: 3 additions & 1 deletion griptape/drivers/embedding/huggingface_hub_embedding_driver.py
@@ -32,7 +32,9 @@ def client(self) -> InferenceClient:
             token=self.api_token,
         )

-    def try_embed_chunk(self, chunk: str) -> list[float]:
+    def try_embed_chunk(self, chunk: str | bytes) -> list[float]:
+        if isinstance(chunk, bytes):
+            raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.")
         response = self.client.feature_extraction(chunk)

         return response.flatten().tolist()
4 changes: 3 additions & 1 deletion griptape/drivers/embedding/ollama_embedding_driver.py
@@ -30,5 +30,7 @@ class OllamaEmbeddingDriver(BaseEmbeddingDriver):
     def client(self) -> Client:
         return import_optional_dependency("ollama").Client(host=self.host)

-    def try_embed_chunk(self, chunk: str) -> list[float]:
+    def try_embed_chunk(self, chunk: str | bytes) -> list[float]:
+        if isinstance(chunk, bytes):
+            raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.")
         return list(self.client.embeddings(model=self.model, prompt=chunk)["embedding"])
4 changes: 3 additions & 1 deletion griptape/drivers/embedding/openai_embedding_driver.py
@@ -44,7 +44,9 @@ class OpenAiEmbeddingDriver(BaseEmbeddingDriver):
     def client(self) -> openai.OpenAI:
         return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization)

-    def try_embed_chunk(self, chunk: str) -> list[float]:
+    def try_embed_chunk(self, chunk: str | bytes) -> list[float]:
+        if isinstance(chunk, bytes):
+            raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.")
         # Address a performance issue in older ada models
         # https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
         if self.model.endswith("001"):
4 changes: 3 additions & 1 deletion griptape/drivers/embedding/voyageai_embedding_driver.py
@@ -40,5 +40,7 @@ class VoyageAiEmbeddingDriver(BaseEmbeddingDriver):
     def client(self) -> Any:
         return import_optional_dependency("voyageai").Client(api_key=self.api_key)

-    def try_embed_chunk(self, chunk: str) -> list[float]:
+    def try_embed_chunk(self, chunk: str | bytes) -> list[float]:
+        if isinstance(chunk, bytes):
+            raise ValueError(f"{self.__class__.__name__} does not support embedding bytes.")
         return self.client.embed([chunk], model=self.model, input_type=self.input_type).embeddings[0]
20 changes: 0 additions & 20 deletions griptape/drivers/vector/azure_mongodb_vector_store_driver.py
@@ -61,23 +61,3 @@ def query_vector(
             )
             for doc in collection.aggregate(pipeline)
         ]
-
-    def query(
-        self,
-        query: str,
-        *,
-        count: Optional[int] = None,
-        namespace: Optional[str] = None,
-        include_vectors: bool = False,
-        offset: Optional[int] = None,
-        **kwargs,
-    ) -> list[BaseVectorStoreDriver.Entry]:
-        """Queries the MongoDB collection for documents that match the provided query string.
-        Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index.
-        """
-        # Using the embedding driver to convert the query string into a vector
-        vector = self.embedding_driver.embed_string(query)
-        return self.query_vector(
-            vector, count=count, namespace=namespace, include_vectors=include_vectors, offset=offset, **kwargs
-        )
105 changes: 68 additions & 37 deletions griptape/drivers/vector/base_vector_store_driver.py
@@ -2,13 +2,12 @@

 import uuid
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Optional, overload

 from attrs import define, field

 from griptape import utils
-from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact
+from griptape.artifacts import BaseArtifact, ImageArtifact, ListArtifact, TextArtifact
 from griptape.mixins.futures_executor_mixin import FuturesExecutorMixin
 from griptape.mixins.serializable_mixin import SerializableMixin
 from griptape.utils import with_contextvars
@@ -21,17 +20,13 @@
 class BaseVectorStoreDriver(SerializableMixin, FuturesExecutorMixin, ABC):
     DEFAULT_QUERY_COUNT = 5

-    @dataclass
-    class Entry:
-        id: str
-        vector: Optional[list[float]] = None
-        score: Optional[float] = None
-        meta: Optional[dict] = None
-        namespace: Optional[str] = None
-
-        @staticmethod
-        def from_dict(data: dict[str, Any]) -> BaseVectorStoreDriver.Entry:
-            return BaseVectorStoreDriver.Entry(**data)
+    @define
+    class Entry(SerializableMixin):
+        id: str = field(metadata={"serializable": True})
+        vector: Optional[list[float]] = field(default=None, metadata={"serializable": True})
+        score: Optional[float] = field(default=None, metadata={"serializable": True})
+        meta: Optional[dict] = field(default=None, metadata={"serializable": True})
+        namespace: Optional[str] = field(default=None, metadata={"serializable": True})

         def to_artifact(self) -> BaseArtifact:
             return BaseArtifact.from_json(self.meta["artifact"])  # pyright: ignore[reportOptionalSubscript]
@@ -40,18 +35,65 @@ def to_artifact(self) -> BaseArtifact:

     def upsert_text_artifacts(
         self,
-        artifacts: list[TextArtifact] | dict[str, list[TextArtifact]],
+        artifacts: list[TextArtifact] | dict[str, list[TextArtifact] | list[ImageArtifact]],
         *,
         meta: Optional[dict] = None,
         **kwargs,
     ) -> list[str] | dict[str, list[str]]:
+        return self.upsert_collection(artifacts, meta=meta, **kwargs)
+
+    def upsert_text_artifact(
+        self,
+        artifact: TextArtifact,
+        *,
+        namespace: Optional[str] = None,
+        meta: Optional[dict] = None,
+        vector_id: Optional[str] = None,
+        **kwargs,
+    ) -> str:
+        return self.upsert(artifact, namespace=namespace, meta=meta, vector_id=vector_id, **kwargs)
+
+    def upsert_text(
+        self,
+        string: str,
+        *,
+        namespace: Optional[str] = None,
+        meta: Optional[dict] = None,
+        vector_id: Optional[str] = None,
+        **kwargs,
+    ) -> str:
+        return self.upsert(string, namespace=namespace, meta=meta, vector_id=vector_id, **kwargs)
+
+    @overload
+    def upsert_collection(
+        self,
+        artifacts: list[TextArtifact],
+        *,
+        meta: Optional[dict] = None,
+        **kwargs,
+    ) -> list[str]: ...
+
+    @overload
+    def upsert_collection(
+        self,
+        artifacts: dict[str, list[TextArtifact] | list[ImageArtifact]],
+        *,
+        meta: Optional[dict] = None,
+        **kwargs,
+    ) -> dict[str, list[str]]: ...
+
+    def upsert_collection(
+        self,
+        artifacts: list[TextArtifact] | dict[str, list[TextArtifact] | list[ImageArtifact]],
+        *,
+        meta: Optional[dict] = None,
+        **kwargs,
+    ):
         with self.create_futures_executor() as futures_executor:
             if isinstance(artifacts, list):
                 return utils.execute_futures_list(
                     [
-                        futures_executor.submit(
-                            with_contextvars(self.upsert_text_artifact), a, namespace=None, meta=meta, **kwargs
-                        )
+                        futures_executor.submit(with_contextvars(self.upsert), a, namespace=None, meta=meta, **kwargs)
                         for a in artifacts
                     ],
                 )
@@ -65,21 +107,23 @@ def upsert_text_artifacts(

                 futures_dict[namespace].append(
                     futures_executor.submit(
-                        with_contextvars(self.upsert_text_artifact), a, namespace=namespace, meta=meta, **kwargs
+                        with_contextvars(self.upsert), a, namespace=namespace, meta=meta, **kwargs
                     )
                 )

             return utils.execute_futures_list_dict(futures_dict)

-    def upsert_text_artifact(
+    def upsert(
         self,
-        artifact: TextArtifact,
+        value: str | TextArtifact | ImageArtifact,
         *,
         namespace: Optional[str] = None,
         meta: Optional[dict] = None,
         vector_id: Optional[str] = None,
         **kwargs,
     ) -> str:
+        artifact = TextArtifact(value) if isinstance(value, str) else value
+
         meta = {} if meta is None else meta

         if vector_id is None:
@@ -91,23 +135,10 @@ def upsert_text_artifact(
         else:
             meta["artifact"] = artifact.to_json()

-        vector = artifact.embedding or artifact.generate_embedding(self.embedding_driver)
+        vector = self.embedding_driver.embed(artifact)

         return self.upsert_vector(vector, vector_id=vector_id, namespace=namespace, meta=meta, **kwargs)

-    def upsert_text(
-        self,
-        string: str,
-        *,
-        namespace: Optional[str] = None,
-        meta: Optional[dict] = None,
-        vector_id: Optional[str] = None,
-        **kwargs,
-    ) -> str:
-        return self.upsert_text_artifact(
-            TextArtifact(string), vector_id=vector_id, namespace=namespace, meta=meta, **kwargs
-        )
-
     def does_entry_exist(self, vector_id: str, *, namespace: Optional[str] = None) -> bool:
         try:
             return self.load_entry(vector_id, namespace=namespace) is not None
@@ -154,14 +185,14 @@ def query_vector(

     def query(
         self,
-        query: str,
+        query: str | TextArtifact | ImageArtifact,
         *,
         count: Optional[int] = None,
         namespace: Optional[str] = None,
         include_vectors: bool = False,
         **kwargs,
     ) -> list[Entry]:
-        vector = self.embedding_driver.embed_string(query)
+        vector = self.embedding_driver.embed(query)
         return self.query_vector(vector, count=count, namespace=namespace, include_vectors=include_vectors, **kwargs)

     def _get_default_vector_id(self, value: str) -> str:
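
With `upsert` and `query` living on the base class, one driver instance can index and search text and images through a single code path. A hypothetical usage sketch (the embedding driver, `png_bytes`, and `ImageArtifact` constructor arguments are placeholders, and import paths may differ across griptape versions):

# Illustrative sketch, not runnable as-is: multimodal_embedding_driver stands
# in for any embedding driver that accepts image bytes.
from griptape.artifacts import ImageArtifact
from griptape.drivers import LocalVectorStoreDriver  # import path may vary

vector_store = LocalVectorStoreDriver(embedding_driver=multimodal_embedding_driver)

# upsert() accepts a raw string, a TextArtifact, or an ImageArtifact;
# raw strings are wrapped in a TextArtifact before being embedded.
text_id = vector_store.upsert("a photo of a sunset over the ocean")
image_id = vector_store.upsert(
    ImageArtifact(png_bytes, format="png", width=64, height=64),  # fields may vary
    namespace="images",
)

# query() routes through embedding_driver.embed(), so with a multi-modal
# embedding driver the same call searches text and images alike.
for entry in vector_store.query("sunset", count=3, namespace="images"):
    print(entry.id, entry.score)
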
3 changes: 2 additions & 1 deletion griptape/drivers/vector/dummy_vector_store_driver.py
@@ -9,6 +9,7 @@
 from griptape.exceptions import DummyError

 if TYPE_CHECKING:
+    from griptape.artifacts import BaseArtifact
     from griptape.drivers.embedding import BaseEmbeddingDriver
@@ -52,7 +53,7 @@ def query_vector(

     def query(
         self,
-        query: str,
+        query: str | BaseArtifact,
         *,
         count: Optional[int] = None,
         namespace: Optional[str] = None,