-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' of github.com:awinml/voyage-embedders-haystack in…
…to main
- Loading branch information
Showing
7 changed files
with
607 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# SPDX-FileCopyrightText: 2023-present Ashwin Mathur <> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
__version__ = "1.0.0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# SPDX-FileCopyrightText: 2023-present John Doe <jd@example.com> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from voyage_embedders.voyage_document_embedder import VoyageDocumentEmbedder | ||
from voyage_embedders.voyage_text_embedder import VoyageTextEmbedder | ||
|
||
__all__ = ["VoyageDocumentEmbedder", "VoyageTextEmbedder"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
import os | ||
from typing import Any, Dict, List, Optional | ||
|
||
import voyageai | ||
from haystack.preview import Document, component, default_to_dict | ||
from tqdm import tqdm | ||
from voyageai import get_embeddings | ||
|
||
MAX_BATCH_SIZE = 8 | ||
|
||
|
||
@component | ||
class VoyageDocumentEmbedder: | ||
""" | ||
A component for computing Document embeddings using Voyage Embedding models. | ||
The embedding of each Document is stored in the `embedding` field of the Document. | ||
Usage example: | ||
```python | ||
from haystack.preview import Document | ||
from haystack.preview.components.embedders import VoyageDocumentEmbedder | ||
doc = Document(text="I love pizza!") | ||
document_embedder = VoyageDocumentEmbedder() | ||
result = document_embedder.run([doc]) | ||
print(result['documents'][0].embedding) | ||
# [0.017020374536514282, -0.023255806416273117, ...] | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, | ||
api_key: Optional[str] = None, | ||
model_name: str = "voyage-01", | ||
prefix: str = "", | ||
suffix: str = "", | ||
batch_size: int = 8, | ||
metadata_fields_to_embed: Optional[List[str]] = None, | ||
embedding_separator: str = "\n", | ||
progress_bar: bool = True, # noqa | ||
): | ||
""" | ||
Create a VoyageDocumentEmbedder component. | ||
:param api_key: The VoyageAI API key. It can be explicitly provided or automatically read from the | ||
environment variable VOYAGE_API_KEY (recommended). | ||
:param model_name: The name of the model to use. Defaults to "voyage-01". | ||
For more details on the available models, | ||
see [Voyage Embeddings documentation](https://docs.voyageai.com/embeddings/). | ||
:param prefix: A string to add to the beginning of each text. | ||
:param suffix: A string to add to the end of each text. | ||
:param batch_size: Number of Documents to encode at once. | ||
:param metadata_fields_to_embed: List of meta fields that should be embedded along with the Document text. | ||
:param embedding_separator: Separator used to concatenate the meta fields to the Document text. | ||
:param progress_bar: Whether to show a progress bar or not. Can be helpful to disable in production deployments | ||
to keep the logs clean. | ||
""" | ||
# if the user does not provide the API key, check if it is set in the module client | ||
api_key = api_key or voyageai.api_key | ||
if api_key is None: | ||
try: | ||
api_key = os.environ["VOYAGE_API_KEY"] | ||
except KeyError as e: | ||
msg = "VoyageDocumentEmbedder expects an VoyageAI API key. Set the VOYAGE_API_KEY environment variable (recommended) or pass it explicitly." # noqa | ||
raise ValueError(msg) from e | ||
|
||
voyageai.api_key = api_key | ||
|
||
self.model_name = model_name | ||
self.prefix = prefix | ||
self.suffix = suffix | ||
|
||
if batch_size <= MAX_BATCH_SIZE: | ||
self.batch_size = batch_size | ||
else: | ||
err_msg = f"""VoyageDocumentEmbedder has a maximum batch size of {MAX_BATCH_SIZE}. Set the Set the batch_size to {MAX_BATCH_SIZE} or less.""" # noqa | ||
raise ValueError(err_msg) | ||
|
||
self.progress_bar = progress_bar | ||
self.metadata_fields_to_embed = metadata_fields_to_embed or [] | ||
self.embedding_separator = embedding_separator | ||
|
||
def to_dict(self) -> Dict[str, Any]: | ||
""" | ||
This method overrides the default serializer in order to avoid leaking the `api_key` value passed | ||
to the constructor. | ||
""" | ||
return default_to_dict( | ||
self, | ||
model_name=self.model_name, | ||
prefix=self.prefix, | ||
suffix=self.suffix, | ||
batch_size=self.batch_size, | ||
progress_bar=self.progress_bar, | ||
metadata_fields_to_embed=self.metadata_fields_to_embed, | ||
embedding_separator=self.embedding_separator, | ||
) | ||
|
||
def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: | ||
""" | ||
Prepare the texts to embed by concatenating the Document text with the metadata fields to embed. | ||
""" | ||
texts_to_embed = [] | ||
for doc in documents: | ||
meta_values_to_embed = [ | ||
str(doc.meta[key]) | ||
for key in self.metadata_fields_to_embed | ||
if key in doc.meta and doc.meta[key] is not None | ||
] | ||
|
||
text_to_embed = ( | ||
self.prefix + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self.suffix | ||
) | ||
|
||
texts_to_embed.append(text_to_embed) | ||
return texts_to_embed | ||
|
||
def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> List[List[float]]: | ||
""" | ||
Embed a list of texts in batches. | ||
""" | ||
|
||
all_embeddings = [] | ||
for i in tqdm( | ||
range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings" | ||
): | ||
batch = texts_to_embed[i : i + batch_size] | ||
embeddings = get_embeddings(list_of_text=batch, batch_size=batch_size, model=self.model_name) | ||
all_embeddings.extend(embeddings) | ||
|
||
return all_embeddings | ||
|
||
@component.output_types(documents=List[Document]) | ||
def run(self, documents: List[Document]): | ||
""" | ||
Embed a list of Documents. | ||
The embedding of each Document is stored in the `embedding` field of the Document. | ||
:param documents: A list of Documents to embed. | ||
""" | ||
if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): | ||
msg = "VoyageDocumentEmbedder expects a list of Documents as input.In case you want to embed a string, please use the VoyageTextEmbedder." # noqa | ||
raise TypeError(msg) | ||
|
||
texts_to_embed = self._prepare_texts_to_embed(documents=documents) | ||
|
||
embeddings = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size) | ||
|
||
for doc, emb in zip(documents, embeddings): | ||
doc.embedding = emb | ||
|
||
return {"documents": documents} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import os | ||
from typing import Any, Dict, List, Optional | ||
|
||
import voyageai | ||
from haystack.preview import component, default_to_dict | ||
from voyageai import get_embedding | ||
|
||
|
||
@component | ||
class VoyageTextEmbedder: | ||
""" | ||
A component for embedding strings using Voyage models. | ||
Usage example: | ||
```python | ||
from haystack.preview.components.embedders import VoyageTextEmbedder | ||
text_to_embed = "I love pizza!" | ||
text_embedder = VoyageTextEmbedder() | ||
print(text_embedder.run(text_to_embed)) | ||
# {'embedding': [0.017020374536514282, -0.023255806416273117, ...], | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, | ||
api_key: Optional[str] = None, | ||
model_name: str = "voyage-01", | ||
prefix: str = "", | ||
suffix: str = "", | ||
): | ||
""" | ||
Create an VoyageTextEmbedder component. | ||
:param api_key: The VoyageAI API key. It can be explicitly provided or automatically read from the | ||
environment variable VOYAGE_API_KEY (recommended). | ||
:param model_name: The name of the Voyage model to use. Defaults to "voyage-01". | ||
For more details on the available models, | ||
see [Voyage Embeddings documentation](https://docs.voyageai.com/embeddings/). | ||
:param prefix: A string to add to the beginning of each text. | ||
:param suffix: A string to add to the end of each text. | ||
""" | ||
# if the user does not provide the API key, check if it is set in the module client | ||
api_key = api_key or voyageai.api_key | ||
if api_key is None: | ||
try: | ||
api_key = os.environ["VOYAGE_API_KEY"] | ||
except KeyError as e: | ||
msg = "VoyageTextEmbedder expects an VoyageAI API key. Set the VOYAGE_API_KEY environment variable (recommended) or pass it explicitly." # noqa | ||
raise ValueError(msg) from e | ||
|
||
voyageai.api_key = api_key | ||
|
||
self.model_name = model_name | ||
self.prefix = prefix | ||
self.suffix = suffix | ||
|
||
def to_dict(self) -> Dict[str, Any]: | ||
""" | ||
This method overrides the default serializer in order to avoid leaking the `api_key` value passed | ||
to the constructor. | ||
""" | ||
|
||
return default_to_dict(self, model_name=self.model_name, prefix=self.prefix, suffix=self.suffix) | ||
|
||
@component.output_types(embedding=List[float]) | ||
def run(self, text: str): | ||
"""Embed a string.""" | ||
if not isinstance(text, str): | ||
msg = "VoyageTextEmbedder expects a string as an input.In case you want to embed a list of Documents, please use the VoyageDocumentEmbedder." # noqa | ||
raise TypeError(msg) | ||
|
||
text_to_embed = self.prefix + text + self.suffix | ||
|
||
embedding = get_embedding(text=text_to_embed, model=self.model_name) | ||
|
||
return {"embedding": embedding} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# SPDX-FileCopyrightText: 2023-present Ashwin Mathur <> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 |
Oops, something went wrong.