Skip to content

Commit c6456bb

Browse files
authored
feat: Adding VoyageAI reranker (#31)
* Adding VoyageAI reranker * Adding VoyageAI reranker * Refactoring due to the comments * Correcting the docs
1 parent 64d25b8 commit c6456bb

File tree

4 files changed

+385
-0
lines changed

4 files changed

+385
-0
lines changed

examples/reranker_example.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from haystack import Document
2+
from haystack_integrations.components.rankers.voyage.voyage_text_reranker import VoyageRanker
3+
4+
ranker = VoyageRanker(model="rerank-2", top_k=2)
5+
6+
docs = [Document(content="Paris"), Document(content="Berlin")]
7+
query = "What is the capital of germany?"
8+
output = ranker.run(query=query, documents=docs)
9+
docs = output["documents"]
10+
11+
for doc in docs:
12+
print(f"{doc.content} - {doc.score}")
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from haystack_integrations.components.rankers.voyage.ranker import VoyageRanker
2+
3+
__all__ = ["VoyageRanker"]
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
import os
2+
from typing import Any, Dict, List, Optional
3+
4+
from haystack import component, default_from_dict, default_to_dict, Document, logging
5+
from haystack.utils import Secret, deserialize_secrets_inplace
6+
from voyageai import Client
7+
8+
logger = logging.getLogger(__name__)
9+
10+
MAX_NUM_DOCS = 1000
11+
12+
13+
@component
14+
class VoyageRanker:
15+
"""
16+
A component for reranking using Voyage models.
17+
18+
Usage example:
19+
```python
20+
from haystack import Document
21+
from haystack_integrations.components.rankers.voyage.ranker import VoyageRanker
22+
23+
ranker = VoyageRanker(model="rerank-2", top_k=2)
24+
25+
docs = [Document(content="Paris"), Document(content="Berlin")]
26+
query = "What is the capital of germany?"
27+
output = ranker.run(query=query, documents=docs)
28+
docs = output["documents"]
29+
```
30+
"""
31+
32+
def __init__(
33+
self,
34+
api_key: Secret = Secret.from_env_var("VOYAGE_API_KEY"),
35+
model: str = "rerank-2",
36+
truncate: Optional[bool] = None,
37+
top_k: Optional[int] = None,
38+
prefix: str = "",
39+
suffix: str = "",
40+
timeout: Optional[int] = None,
41+
max_retries: Optional[int] = None,
42+
meta_fields_to_embed: Optional[List[str]] = None,
43+
meta_data_separator: str = "\n",
44+
):
45+
"""
46+
Create an VoyageRanker component.
47+
48+
:param api_key:
49+
The VoyageAI API key. It can be explicitly provided or automatically read from the environment variable
50+
VOYAGE_API_KEY (recommended).
51+
:param model:
52+
The name of the Voyage model to use. Defaults to "voyage-2".
53+
For more details on the available models,
54+
see [Voyage Rerankers documentation](https://docs.voyageai.com/docs/reranker).
55+
:param truncate:
56+
Whether to truncate the input texts to fit within the context length.
57+
- If `True`, over-length input texts will be truncated to fit within the context length, before vectorized
58+
by the reranker model.
59+
- If False, an error will be raised if any given text exceeds the context length.
60+
- Defaults to `None`, which will truncate the input text before sending it to the reranker model if it
61+
slightly exceeds the context window length. If it significantly exceeds the context window length, an
62+
error will be raised.
63+
:param top_k:
64+
The number of most relevant documents to return.
65+
If not specified, the reranking results of all documents will be returned.
66+
:param prefix:
67+
A string to add to the beginning of each text.
68+
:param suffix:
69+
A string to add to the end of each text.
70+
:param timeout:
71+
Timeout for VoyageAI Client calls, if not set it is inferred from the `VOYAGE_TIMEOUT` environment variable
72+
or set to 30.
73+
:param max_retries:
74+
Maximum retries to establish contact with VoyageAI if it returns an internal error, if not set it is
75+
inferred from the `VOYAGE_MAX_RETRIES` environment variable or set to 5.
76+
"""
77+
self.api_key = api_key
78+
self.model = model
79+
self.top_k = top_k
80+
self.truncate = truncate
81+
self.prefix = prefix
82+
self.suffix = suffix
83+
self.meta_fields_to_embed = meta_fields_to_embed or []
84+
self.meta_data_separator = meta_data_separator
85+
86+
if timeout is None:
87+
timeout = int(os.environ.get("VOYAGE_TIMEOUT", 30))
88+
if max_retries is None:
89+
max_retries = int(os.environ.get("VOYAGE_MAX_RETRIES", 5))
90+
91+
self.client = Client(api_key=api_key.resolve_value(), max_retries=max_retries, timeout=timeout)
92+
93+
def to_dict(self) -> Dict[str, Any]:
94+
"""
95+
Serializes the component to a dictionary.
96+
97+
:returns:
98+
Dictionary with serialized data.
99+
"""
100+
return default_to_dict(
101+
self,
102+
model=self.model,
103+
top_k=self.top_k,
104+
truncate=self.truncate,
105+
prefix=self.prefix,
106+
suffix=self.suffix,
107+
api_key=self.api_key.to_dict(),
108+
meta_fields_to_embed=self.meta_fields_to_embed,
109+
meta_data_separator=self.meta_data_separator,
110+
)
111+
112+
@classmethod
113+
def from_dict(cls, data: Dict[str, Any]) -> "VoyageRanker":
114+
"""
115+
Deserializes the component from a dictionary.
116+
117+
:param data:
118+
Dictionary to deserialize from.
119+
:returns:
120+
Deserialized component.
121+
"""
122+
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
123+
return default_from_dict(cls, data)
124+
125+
def _prepare_input_docs(self, documents: List[Document]) -> List[str]:
126+
"""
127+
Prepare the input by concatenating the document text with the metadata fields specified.
128+
:param documents:
129+
The list of Document objects.
130+
131+
:return:
132+
A list of strings to be given as input to Voyage AI model.
133+
"""
134+
concatenated_input_list = []
135+
for doc in documents:
136+
meta_values_to_embed = [
137+
str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta.get(key)
138+
]
139+
concatenated_input = self.meta_data_separator.join([*meta_values_to_embed, doc.content or ""])
140+
concatenated_input_list.append(concatenated_input)
141+
142+
return concatenated_input_list
143+
144+
@component.output_types(documents=List[Document])
145+
def run(self, query: str, documents: List[Document], top_k: Optional[int] = None):
146+
"""
147+
Use the Voyage AI Reranker to re-rank the list of documents based on the query.
148+
149+
:param query:
150+
Query string.
151+
:param documents:
152+
List of Documents.
153+
:param top_k:
154+
The maximum number of Documents you want the Ranker to return.
155+
:returns:
156+
A dictionary with the following keys:
157+
- `documents`: List of Documents most similar to the given query in descending order of similarity.
158+
159+
:raises ValueError: If `top_k` is not > 0.
160+
"""
161+
top_k = top_k or self.top_k
162+
if top_k is not None and top_k <= 0:
163+
msg = f"top_k must be > 0, but got {top_k}"
164+
raise ValueError(msg)
165+
166+
input_docs = self._prepare_input_docs(documents)
167+
if len(input_docs) > MAX_NUM_DOCS:
168+
logger.warning(
169+
f"The Voyage AI reranking endpoint only supports {MAX_NUM_DOCS} documents.\
170+
The number of documents has been truncated to {MAX_NUM_DOCS} \
171+
from {len(input_docs)}."
172+
)
173+
input_docs = input_docs[:MAX_NUM_DOCS]
174+
175+
response = self.client.rerank(
176+
model=self.model,
177+
query=query,
178+
documents=input_docs,
179+
top_k=top_k,
180+
)
181+
indices = [output.index for output in response.results]
182+
scores = [output.relevance_score for output in response.results]
183+
sorted_docs = []
184+
for idx, score in zip(indices, scores):
185+
doc = documents[idx]
186+
doc.score = score
187+
sorted_docs.append(documents[idx])
188+
return {"documents": sorted_docs}

tests/test_ranker.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
import os
2+
3+
import pytest
4+
from haystack import Document
5+
from haystack.utils.auth import Secret
6+
from voyageai.error import InvalidRequestError
7+
8+
from haystack_integrations.components.rankers.voyage import VoyageRanker
9+
10+
11+
class TestVoyageTextReranker:
12+
@pytest.mark.unit
13+
def test_init_default(self, monkeypatch):
14+
monkeypatch.setenv("VOYAGE_API_KEY", "fake-api-key")
15+
reranker = VoyageRanker()
16+
17+
assert reranker.client.api_key == "fake-api-key"
18+
assert reranker.model == "rerank-2"
19+
assert reranker.truncate is None
20+
assert reranker.prefix == ""
21+
assert reranker.suffix == ""
22+
assert reranker.top_k is None
23+
assert reranker.meta_fields_to_embed == []
24+
assert reranker.meta_data_separator == "\n"
25+
26+
@pytest.mark.unit
27+
def test_init_with_parameters(self):
28+
reranker = VoyageRanker(
29+
api_key=Secret.from_token("fake-api-key"),
30+
model="model",
31+
truncate=True,
32+
top_k=10,
33+
prefix="prefix",
34+
suffix="suffix",
35+
meta_fields_to_embed=["meta_field_1", "meta_field_2"],
36+
meta_data_separator=",",
37+
)
38+
assert reranker.client.api_key == "fake-api-key"
39+
assert reranker.model == "model"
40+
assert reranker.truncate is True
41+
assert reranker.top_k == 10
42+
assert reranker.prefix == "prefix"
43+
assert reranker.suffix == "suffix"
44+
assert reranker.meta_fields_to_embed == ["meta_field_1", "meta_field_2"]
45+
assert reranker.meta_data_separator == ","
46+
47+
@pytest.mark.unit
48+
def test_init_fail_wo_api_key(self, monkeypatch):
49+
monkeypatch.delenv("VOYAGE_API_KEY", raising=False)
50+
with pytest.raises(ValueError, match="None of the .* environment variables are set"):
51+
VoyageRanker()
52+
53+
@pytest.mark.unit
54+
def test_to_dict(self, monkeypatch):
55+
monkeypatch.setenv("VOYAGE_API_KEY", "fake-api-key")
56+
component = VoyageRanker()
57+
data = component.to_dict()
58+
assert data == {
59+
"type": "haystack_integrations.components.rankers.voyage.ranker."
60+
"VoyageRanker",
61+
"init_parameters": {
62+
"api_key": {"env_vars": ["VOYAGE_API_KEY"], "strict": True, "type": "env_var"},
63+
"model": "rerank-2",
64+
"truncate": None,
65+
"top_k": None,
66+
"prefix": "",
67+
"suffix": "",
68+
"meta_fields_to_embed": [],
69+
"meta_data_separator": "\n"
70+
},
71+
}
72+
73+
@pytest.mark.unit
74+
def test_from_dict(self, monkeypatch):
75+
monkeypatch.setenv("VOYAGE_API_KEY", "fake-api-key")
76+
data = {
77+
"type": "haystack_integrations.components.rankers.voyage.ranker."
78+
"VoyageRanker",
79+
"init_parameters": {
80+
"api_key": {"env_vars": ["VOYAGE_API_KEY"], "strict": True, "type": "env_var"},
81+
"model": "rerank-2",
82+
"truncate": None,
83+
"top_k": 10,
84+
"prefix": "",
85+
"suffix": "",
86+
"meta_fields_to_embed": None,
87+
"meta_data_separator": "\n"
88+
},
89+
}
90+
91+
reranker = VoyageRanker.from_dict(data)
92+
assert reranker.client.api_key == "fake-api-key"
93+
assert reranker.top_k == 10
94+
assert reranker.model == "rerank-2"
95+
assert reranker.truncate is None
96+
assert reranker.prefix == ""
97+
assert reranker.suffix == ""
98+
assert reranker.meta_fields_to_embed == []
99+
assert reranker.meta_data_separator == '\n'
100+
101+
@pytest.mark.unit
102+
def test_to_dict_with_custom_init_parameters(self, monkeypatch):
103+
monkeypatch.setenv("ENV_VAR", "fake-api-key")
104+
component = VoyageRanker(
105+
api_key=Secret.from_env_var("ENV_VAR", strict=False),
106+
model="model",
107+
truncate=True,
108+
top_k=10,
109+
prefix="prefix",
110+
suffix="suffix",
111+
)
112+
data = component.to_dict()
113+
assert data == {
114+
"type": "haystack_integrations.components.rankers.voyage.ranker."
115+
"VoyageRanker",
116+
"init_parameters": {
117+
"api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
118+
"model": "model",
119+
"truncate": True,
120+
"top_k": 10,
121+
"prefix": "prefix",
122+
"suffix": "suffix",
123+
'meta_data_separator': '\n',
124+
'meta_fields_to_embed': [],
125+
126+
},
127+
}
128+
129+
@pytest.mark.unit
130+
def test_from_dict_with_custom_init_parameters(self, monkeypatch):
131+
monkeypatch.setenv("ENV_VAR", "fake-api-key")
132+
data = {
133+
"type": "haystack_integrations.components.rankers.voyage.ranker."
134+
"VoyageRanker",
135+
"init_parameters": {
136+
"api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
137+
"model": "model",
138+
"truncate": True,
139+
"top_k": 10,
140+
"prefix": "prefix",
141+
"suffix": "suffix",
142+
},
143+
}
144+
145+
reranker = VoyageRanker.from_dict(data)
146+
assert reranker.client.api_key == "fake-api-key"
147+
assert reranker.model == "model"
148+
assert reranker.truncate is True
149+
assert reranker.top_k == 10
150+
assert reranker.prefix == "prefix"
151+
assert reranker.suffix == "suffix"
152+
153+
@pytest.mark.skipif(os.environ.get("VOYAGE_API_KEY", "") == "", reason="VOYAGE_API_KEY is not set")
154+
@pytest.mark.integration
155+
def test_run(self):
156+
model = "rerank-2"
157+
158+
documents = [
159+
Document(id="abcd", content="Paris is in France"),
160+
Document(id="efgh", content="Berlin is in Germany"),
161+
Document(id="ijkl", content="Lyon is in France"),
162+
]
163+
164+
reranker = VoyageRanker(model=model, prefix="prefix ", suffix=" suffix")
165+
result = reranker.run(query="The food was delicious", documents=documents, top_k=2)
166+
167+
assert len(result["documents"]) == 2
168+
assert all(isinstance(x, Document) for x in result["documents"])
169+
170+
@pytest.mark.unit
171+
def test_run_wrong_input_format(self):
172+
reranker = VoyageRanker(api_key=Secret.from_token("fake-api-key"))
173+
174+
integer_input = 1
175+
documents = [
176+
Document(id="abcd", content="Paris is in France"),
177+
Document(id="efgh", content="Berlin is in Germany"),
178+
Document(id="ijkl", content="Lyon is in France"),
179+
]
180+
181+
with pytest.raises(InvalidRequestError, match=f"not a valid string"):
182+
reranker.run(query=integer_input, documents=documents)

0 commit comments

Comments
 (0)