Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
awinml committed Dec 11, 2023
1 parent e2bedf7 commit 61f852c
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/voyage_embedders/voyage_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from typing import Any, Dict, List, Optional

import voyageai
from haystack.dataclasses import Document
from haystack.core.component import component
from haystack.core.serialization import default_to_dict
from haystack.dataclasses import Document
from tqdm import tqdm
from voyageai import get_embeddings

Expand Down
5 changes: 3 additions & 2 deletions tests/test_voyage_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_to_dict(self):
component = VoyageDocumentEmbedder(api_key="fake-api-key")
data = component.to_dict()
assert data == {
"type": "VoyageDocumentEmbedder",
"type": "voyage_embedders.voyage_document_embedder.VoyageDocumentEmbedder",
"init_parameters": {
"model_name": "voyage-01",
"prefix": "",
Expand All @@ -91,7 +91,7 @@ def test_to_dict_with_custom_init_parameters(self):
)
data = component.to_dict()
assert data == {
"type": "VoyageDocumentEmbedder",
"type": "voyage_embedders.voyage_document_embedder.VoyageDocumentEmbedder",
"init_parameters": {
"model_name": "model",
"prefix": "prefix",
Expand Down Expand Up @@ -187,6 +187,7 @@ def test_run(self):
"prefix ML | A transformer is a deep learning architecture suffix",
],
batch_size=8,
input_type="document",
)
documents_with_embeddings = result["documents"]

Expand Down
8 changes: 5 additions & 3 deletions tests/test_voyage_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_to_dict(self):
component = VoyageTextEmbedder(api_key="fake-api-key")
data = component.to_dict()
assert data == {
"type": "VoyageTextEmbedder",
"type": "voyage_embedders.voyage_text_embedder.VoyageTextEmbedder",
"init_parameters": {
"model_name": "voyage-01",
"input_type": "query",
Expand All @@ -73,7 +73,7 @@ def test_to_dict_with_custom_init_parameters(self):
)
data = component.to_dict()
assert data == {
"type": "VoyageTextEmbedder",
"type": "voyage_embedders.voyage_text_embedder.VoyageTextEmbedder",
"init_parameters": {
"model_name": "model",
"input_type": "document",
Expand All @@ -92,7 +92,9 @@ def test_run(self):
embedder = VoyageTextEmbedder(api_key="fake-api-key", model_name=model, prefix="prefix ", suffix=" suffix")
result = embedder.run(text="The food was delicious")

voyageai_embedding_patch.assert_called_once_with(model=model, text="prefix The food was delicious suffix")
voyageai_embedding_patch.assert_called_once_with(
model=model, text="prefix The food was delicious suffix", input_type="query"
)

assert len(result["embedding"]) == 1024
assert all(isinstance(x, float) for x in result["embedding"])
Expand Down

0 comments on commit 61f852c

Please sign in to comment.