Skip to content

Commit

Permalink
Merge branch 'main' of github.com:awinml/voyage-embedders-haystack in…
Browse files Browse the repository at this point in the history
…to main
  • Loading branch information
awinml committed Nov 16, 2023
2 parents 57baf83 + e0b7c21 commit a3c9147
Show file tree
Hide file tree
Showing 7 changed files with 607 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/voyage_embedders/__about__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# SPDX-FileCopyrightText: 2023-present Ashwin Mathur <>
#
# SPDX-License-Identifier: Apache-2.0
# Version string of the voyage_embedders package.
__version__ = "1.0.0"
8 changes: 8 additions & 0 deletions src/voyage_embedders/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# SPDX-FileCopyrightText: 2023-present John Doe <jd@example.com>
#
# SPDX-License-Identifier: Apache-2.0

from voyage_embedders.voyage_document_embedder import VoyageDocumentEmbedder
from voyage_embedders.voyage_text_embedder import VoyageTextEmbedder

# Names exported by `from voyage_embedders import *`: the two embedder components imported above.
__all__ = ["VoyageDocumentEmbedder", "VoyageTextEmbedder"]
154 changes: 154 additions & 0 deletions src/voyage_embedders/voyage_document_embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import os
from typing import Any, Dict, List, Optional

import voyageai
from haystack.preview import Document, component, default_to_dict
from tqdm import tqdm
from voyageai import get_embeddings

# Largest batch_size VoyageDocumentEmbedder accepts; its constructor raises ValueError for larger values.
MAX_BATCH_SIZE = 8


@component
class VoyageDocumentEmbedder:
    """
    A component for computing Document embeddings using Voyage Embedding models.

    The embedding of each Document is stored in the `embedding` field of the Document.

    Usage example:
    ```python
    from haystack.preview import Document
    from voyage_embedders import VoyageDocumentEmbedder

    doc = Document(content="I love pizza!")

    document_embedder = VoyageDocumentEmbedder()

    result = document_embedder.run([doc])
    print(result['documents'][0].embedding)
    # [0.017020374536514282, -0.023255806416273117, ...]
    ```
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: str = "voyage-01",
        prefix: str = "",
        suffix: str = "",
        batch_size: int = 8,
        metadata_fields_to_embed: Optional[List[str]] = None,
        embedding_separator: str = "\n",
        progress_bar: bool = True,  # noqa
    ):
        """
        Create a VoyageDocumentEmbedder component.

        :param api_key: The VoyageAI API key. It can be explicitly provided or automatically read from the
            environment variable VOYAGE_API_KEY (recommended).
        :param model_name: The name of the model to use. Defaults to "voyage-01".
            For more details on the available models,
            see [Voyage Embeddings documentation](https://docs.voyageai.com/embeddings/).
        :param prefix: A string to add to the beginning of each text.
        :param suffix: A string to add to the end of each text.
        :param batch_size: Number of Documents to encode at once. Must be between 1 and MAX_BATCH_SIZE.
        :param metadata_fields_to_embed: List of meta fields that should be embedded along with the Document text.
        :param embedding_separator: Separator used to concatenate the meta fields to the Document text.
        :param progress_bar: Whether to show a progress bar or not. Can be helpful to disable in production deployments
            to keep the logs clean.
        :raises ValueError: If no API key can be found, or if `batch_size` is smaller than 1 or larger
            than MAX_BATCH_SIZE.
        """
        # if the user does not provide the API key, check if it is set in the module client
        api_key = api_key or voyageai.api_key
        if api_key is None:
            try:
                api_key = os.environ["VOYAGE_API_KEY"]
            except KeyError as e:
                msg = (
                    "VoyageDocumentEmbedder expects an VoyageAI API key. "
                    "Set the VOYAGE_API_KEY environment variable (recommended) or pass it explicitly."
                )
                raise ValueError(msg) from e

        # The key is stored on the voyageai module, not on the instance, so it is never serialized.
        voyageai.api_key = api_key

        self.model_name = model_name
        self.prefix = prefix
        self.suffix = suffix

        # A zero batch size would make range() fail inside _embed_batch and a negative one would
        # silently produce no embeddings at all, so reject both up front.
        if batch_size < 1:
            err_msg = f"VoyageDocumentEmbedder requires a batch_size of at least 1, got {batch_size}."
            raise ValueError(err_msg)
        if batch_size > MAX_BATCH_SIZE:
            err_msg = (
                f"VoyageDocumentEmbedder has a maximum batch size of {MAX_BATCH_SIZE}. "
                f"Set the batch_size to {MAX_BATCH_SIZE} or less."
            )
            raise ValueError(err_msg)
        self.batch_size = batch_size

        self.progress_bar = progress_bar
        self.metadata_fields_to_embed = metadata_fields_to_embed or []
        self.embedding_separator = embedding_separator

    def to_dict(self) -> Dict[str, Any]:
        """
        This method overrides the default serializer in order to avoid leaking the `api_key` value passed
        to the constructor.

        :returns: The result of `default_to_dict` called with every init parameter except `api_key`.
        """
        return default_to_dict(
            self,
            model_name=self.model_name,
            prefix=self.prefix,
            suffix=self.suffix,
            batch_size=self.batch_size,
            progress_bar=self.progress_bar,
            metadata_fields_to_embed=self.metadata_fields_to_embed,
            embedding_separator=self.embedding_separator,
        )

    def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
        """
        Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.

        :param documents: The Documents whose content and selected meta fields are combined.
        :returns: One string per Document, in the same order as `documents`.
        """
        texts_to_embed = []
        for doc in documents:
            # Meta fields that are missing or set to None are skipped rather than embedded as "None".
            meta_values_to_embed = [
                str(doc.meta[key])
                for key in self.metadata_fields_to_embed
                if key in doc.meta and doc.meta[key] is not None
            ]

            # Meta values come first, then the content (empty string when the Document has none).
            text_to_embed = (
                self.prefix + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self.suffix
            )

            texts_to_embed.append(text_to_embed)
        return texts_to_embed

    def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> List[List[float]]:
        """
        Embed a list of texts in batches.

        :param texts_to_embed: The texts to send to the embedding model.
        :param batch_size: Number of texts sent per `get_embeddings` call.
        :returns: The embeddings returned by each call, concatenated in request order.
        """

        all_embeddings = []
        for i in tqdm(
            range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
        ):
            batch = texts_to_embed[i : i + batch_size]
            embeddings = get_embeddings(list_of_text=batch, batch_size=batch_size, model=self.model_name)
            all_embeddings.extend(embeddings)

        return all_embeddings

    @component.output_types(documents=List[Document])
    def run(self, documents: List[Document]):
        """
        Embed a list of Documents.

        The embedding of each Document is stored in the `embedding` field of the Document.

        :param documents: A list of Documents to embed.
        :returns: A dictionary with the key `documents` holding the same Document objects, updated in place.
        :raises TypeError: If `documents` is not a list, or its first element is not a Document.
        """
        # Only the first element is inspected, which keeps the check cheap for large inputs.
        if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
            msg = (
                "VoyageDocumentEmbedder expects a list of Documents as input. "
                "In case you want to embed a string, please use the VoyageTextEmbedder."
            )
            raise TypeError(msg)

        texts_to_embed = self._prepare_texts_to_embed(documents=documents)

        embeddings = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)

        for doc, emb in zip(documents, embeddings):
            doc.embedding = emb

        return {"documents": documents}
80 changes: 80 additions & 0 deletions src/voyage_embedders/voyage_text_embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import os
from typing import Any, Dict, List, Optional

import voyageai
from haystack.preview import component, default_to_dict
from voyageai import get_embedding


@component
class VoyageTextEmbedder:
    """
    A component for embedding strings using Voyage models.

    Usage example:
    ```python
    from voyage_embedders import VoyageTextEmbedder

    text_to_embed = "I love pizza!"

    text_embedder = VoyageTextEmbedder()

    print(text_embedder.run(text_to_embed))
    # {'embedding': [0.017020374536514282, -0.023255806416273117, ...]}
    ```
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: str = "voyage-01",
        prefix: str = "",
        suffix: str = "",
    ):
        """
        Create a VoyageTextEmbedder component.

        :param api_key: The VoyageAI API key. It can be explicitly provided or automatically read from the
            environment variable VOYAGE_API_KEY (recommended).
        :param model_name: The name of the Voyage model to use. Defaults to "voyage-01".
            For more details on the available models,
            see [Voyage Embeddings documentation](https://docs.voyageai.com/embeddings/).
        :param prefix: A string to add to the beginning of each text.
        :param suffix: A string to add to the end of each text.
        :raises ValueError: If no API key is passed, set on the `voyageai` module, or present in the
            VOYAGE_API_KEY environment variable.
        """
        # if the user does not provide the API key, check if it is set in the module client
        api_key = api_key or voyageai.api_key
        if api_key is None:
            try:
                api_key = os.environ["VOYAGE_API_KEY"]
            except KeyError as e:
                msg = (
                    "VoyageTextEmbedder expects an VoyageAI API key. "
                    "Set the VOYAGE_API_KEY environment variable (recommended) or pass it explicitly."
                )
                raise ValueError(msg) from e

        # The key is stored on the voyageai module, not on the instance, so it is never serialized.
        voyageai.api_key = api_key

        self.model_name = model_name
        self.prefix = prefix
        self.suffix = suffix

    def to_dict(self) -> Dict[str, Any]:
        """
        This method overrides the default serializer in order to avoid leaking the `api_key` value passed
        to the constructor.

        :returns: The result of `default_to_dict` called with `model_name`, `prefix` and `suffix`.
        """

        return default_to_dict(self, model_name=self.model_name, prefix=self.prefix, suffix=self.suffix)

    @component.output_types(embedding=List[float])
    def run(self, text: str):
        """
        Embed a string.

        :param text: The string to embed; `prefix` and `suffix` are added around it before embedding.
        :returns: A dictionary with the key `embedding` holding the value returned by `get_embedding`.
        :raises TypeError: If `text` is not a string.
        """
        if not isinstance(text, str):
            msg = (
                "VoyageTextEmbedder expects a string as an input. "
                "In case you want to embed a list of Documents, please use the VoyageDocumentEmbedder."
            )
            raise TypeError(msg)

        text_to_embed = self.prefix + text + self.suffix

        embedding = get_embedding(text=text_to_embed, model=self.model_name)

        return {"embedding": embedding}
3 changes: 3 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: 2023-present Ashwin Mathur <>
#
# SPDX-License-Identifier: Apache-2.0
Loading

0 comments on commit a3c9147

Please sign in to comment.