Skip to content

Commit

Permalink
Fixed tests.
Browse files (browse the repository at this point in the history)
  • Loading branch information
souradipp76 committed Nov 8, 2024
1 parent e8e2028 commit 0fb6926
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 66 deletions.
17 changes: 8 additions & 9 deletions doc_generator/utils/HNSWLib.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def init_index(self, vectors: List[List[float]]):
self.args.num_dimensions = len(vectors[0])
self._index = HNSWLib.get_hierarchical_nsw(self.args)
if not self._index.element_count:
print("herer")
self._index.init_index(len(vectors))

def add_vectors(
Expand All @@ -94,9 +95,8 @@ def add_vectors(
raise ValueError("Vectors and documents must have the same length")
if len(vectors[0]) != self.args.num_dimensions:
raise ValueError(
f"Vectors must have the same length as the \
number of dimensions \
({self.args.num_dimensions})"
"Vectors must have the same length as the "
+ f"number of dimensions ({self.args.num_dimensions})"
)
assert self._index is not None
capacity = self._index.get_max_elements()
Expand Down Expand Up @@ -147,16 +147,15 @@ def similarity_search_by_vector(
) -> List:
if len(embedding) != self.args.num_dimensions:
raise ValueError(
f"Query vector must have the same length as the \
number of dimensions \
({self.args.num_dimensions})"
"Query vector must have the same length as the "
+ f"number of dimensions ({self.args.num_dimensions})"
)
assert self._index is not None
total = self._index.element_count
if k > total:
print(
f"k ({k}) is greater than the number of elements in the \
index ({total}), setting k to {total}"
f"k ({k}) is greater than the number of elements in the "
+ f"index ({total}), setting k to {total}"
)
k = total
labels, distances = self._index.knn_query(embedding, k)
Expand All @@ -181,7 +180,7 @@ def save(self, directory: str):
with open(os.path.join(directory, "docstore.json"), "w") as f:
docstore_data = []
for key, val in self.docstore._dict.items():
docstore_data.append([key, val.dict()])
docstore_data.append([key, val.model_dump()])
json.dump(docstore_data, f)
with open(os.path.join(directory, "args.json"), "w") as f:
json.dump(
Expand Down
148 changes: 97 additions & 51 deletions tests/utils/test_HNSWLib.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,30 @@
import numpy as np
from unittest.mock import MagicMock, patch

from doc_generator.utils.HNSWLib import HNSWLib, HNSWLibArgs, SaveableVectorStore
from doc_generator.utils.HNSWLib import (
HNSWLib,
HNSWLibArgs,
SaveableVectorStore,
)
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_core.embeddings.embeddings import Embeddings
from langchain_core.documents import Document


def test_init():
embeddings = MagicMock()
args = HNSWLibArgs(space='cosine', num_dimensions=128)
args = HNSWLibArgs(space="cosine", num_dimensions=128)
hnswlib_instance = HNSWLib(embeddings, args)
assert hnswlib_instance._embeddings == embeddings
assert hnswlib_instance.args == args
assert isinstance(hnswlib_instance.docstore, InMemoryDocstore)


@patch('doc_generator.utils.HNSWLib.hnswlib.Index')
@patch("doc_generator.utils.HNSWLib.hnswlib.Index")
def test_add_texts_success(mock_index_class):
embeddings = MagicMock()
embeddings.embed_documents.return_value = [[0.1, 0.2], [0.3, 0.4]]
args = HNSWLibArgs(space='cosine', num_dimensions=2)
args = HNSWLibArgs(space="cosine", num_dimensions=2)
hnswlib_instance = HNSWLib(embeddings, args)

mock_index = MagicMock()
Expand All @@ -49,39 +53,44 @@ def test_get_hierarchical_nsw_no_space():


def test_get_hierarchical_nsw_no_num_dimensions():
args = HNSWLibArgs(space='cosine', num_dimensions=None)
args = HNSWLibArgs(space="cosine", num_dimensions=None)
with pytest.raises(ValueError) as excinfo:
HNSWLib.get_hierarchical_nsw(args)
assert "hnswlib requires a num_dimensions argument" in str(excinfo.value)


def test_add_vectors_mismatched_lengths():
embeddings = MagicMock()
args = HNSWLibArgs(space='cosine', num_dimensions=2)
args = HNSWLibArgs(space="cosine", num_dimensions=2)
hnswlib_instance = HNSWLib(embeddings, args)
vectors = [[0.1, 0.2], [0.3, 0.4]]
documents = [Document("Doc 1")]

with pytest.raises(ValueError) as excinfo:
hnswlib_instance.add_vectors(vectors, documents)
assert "Vectors and documents must have the same length" in str(excinfo.value)
assert "Vectors and documents must have the same length" in str(
excinfo.value
)


def test_add_vectors_wrong_dimension():
embeddings = MagicMock()
args = HNSWLibArgs(space='cosine', num_dimensions=3)
args = HNSWLibArgs(space="cosine", num_dimensions=3)
hnswlib_instance = HNSWLib(embeddings, args)
vectors = [[0.1, 0.2]] # Only 2 dimensions
documents = [Document("Doc 1")]

with pytest.raises(ValueError) as excinfo:
hnswlib_instance.add_vectors(vectors, documents)
assert "Vectors must have the same length as the number of dimensions" in str(excinfo.value)
assert (
"Vectors must have the same length as the number of dimensions"
in str(excinfo.value)
)


def test_add_vectors_resize_index():
embeddings = MagicMock()
args = HNSWLibArgs(space='cosine', num_dimensions=2)
args = HNSWLibArgs(space="cosine", num_dimensions=2)
hnswlib_instance = HNSWLib(embeddings, args)
hnswlib_instance._index = MagicMock()
hnswlib_instance._index.element_count = 100
Expand All @@ -93,13 +102,15 @@ def test_add_vectors_resize_index():

hnswlib_instance.add_vectors(vectors, documents)

hnswlib_instance._index.resize_index.assert_called_with(102) # 100 existing + 2 new
hnswlib_instance._index.resize_index.assert_called_with(
102
) # 100 existing + 2 new


def test_add_documents():
embeddings = MagicMock()
embeddings.embed_documents.return_value = [[0.1, 0.2], [0.3, 0.4]]
args = HNSWLibArgs(space='cosine', num_dimensions=2)
args = HNSWLibArgs(space="cosine", num_dimensions=2)
hnswlib_instance = HNSWLib(embeddings, args)
hnswlib_instance.add_vectors = MagicMock()
documents = [Document("Doc 1"), Document("Doc 2")]
Expand All @@ -113,7 +124,9 @@ def test_add_documents():
def test_from_texts():
embeddings = MagicMock()
embeddings.embed_documents.return_value = [[0.1, 0.2], [0.3, 0.4]]
with patch('doc_generator.utils.HNSWLib.HNSWLib.from_documents') as mock_from_documents:
with patch(
"doc_generator.utils.HNSWLib.HNSWLib.from_documents"
) as mock_from_documents:
HNSWLib.from_texts(["Text 1", "Text 2"], embeddings)
mock_from_documents.assert_called()

Expand All @@ -122,7 +135,9 @@ def test_from_documents_with_docstore():
embeddings = MagicMock()
documents = [Document("Doc 1"), Document("Doc 2")]
docstore = MagicMock()
with patch('doc_generator.utils.HNSWLib.HNSWLib.add_documents') as mock_add_documents:
with patch(
"doc_generator.utils.HNSWLib.HNSWLib.add_documents"
) as mock_add_documents:
hnsw = HNSWLib.from_documents(documents, embeddings, docstore=docstore)
mock_add_documents.assert_called_once_with(documents)

Expand All @@ -136,90 +151,120 @@ def test_from_documents_without_docstore():

def test_similarity_search_by_vector_wrong_dimension():
embeddings = MagicMock()
args = HNSWLibArgs(space='cosine', num_dimensions=3)
args = HNSWLibArgs(space="cosine", num_dimensions=3)
hnswlib_instance = HNSWLib(embeddings, args)
hnswlib_instance._index = MagicMock()
embedding = [0.1, 0.2] # Only 2 dimensions

with pytest.raises(ValueError) as excinfo:
hnswlib_instance.similarity_search_by_vector(embedding)
assert f"Query vector must have the same length as the number of dimensions ({args.num_dimensions})" in str(excinfo.value)
assert (
f"Query vector must have the same length as the number of dimensions ({args.num_dimensions})"
in str(excinfo.value)
)


def test_similarity_search_by_vector_k_greater_than_total(capsys):
embeddings = MagicMock()
args = HNSWLibArgs(space='cosine', num_dimensions=2)
args = HNSWLibArgs(space="cosine", num_dimensions=2)
hnswlib_instance = HNSWLib(embeddings, args)
hnswlib_instance._index = MagicMock()
hnswlib_instance._index.element_count = 1 # Total elements is 1
hnswlib_instance._index.knn_query.return_value = (np.array([[0]]), np.array([[0.0]]))
hnswlib_instance.docstore._dict = {'0': Document("Doc 0")}
hnswlib_instance._index.knn_query.return_value = (
np.array([[0]]),
np.array([[0.0]]),
)
hnswlib_instance.docstore._dict = {"0": Document("Doc 0")}

embedding = [0.1, 0.2]
hnswlib_instance.similarity_search_by_vector(embedding, k=5)
captured = capsys.readouterr()
assert "k (5) is greater than the number of elements in the index (1), setting k to 1" in captured.out
assert (
"k (5) is greater than the number of elements in the index (1), setting k to 1"
in captured.out
)


def test_similarity_search():
embeddings = MagicMock()
embeddings.embed_query.return_value = [0.1, 0.2]
args = HNSWLibArgs(space='cosine', num_dimensions=2)
args = HNSWLibArgs(space="cosine", num_dimensions=2)
hnswlib_instance = HNSWLib(embeddings, args)
hnswlib_instance.similarity_search_by_vector = MagicMock(return_value=[(Document("Doc 1"), 0.0)])
hnswlib_instance.similarity_search_by_vector = MagicMock(
return_value=[(Document("Doc 1"), 0.0)]
)

results = hnswlib_instance.similarity_search("query", k=2)
embeddings.embed_query.assert_called_with("query")
hnswlib_instance.similarity_search_by_vector.assert_called_with([0.1, 0.2], k=2)
hnswlib_instance.similarity_search_by_vector.assert_called_with(
[0.1, 0.2], 2
)
assert len(results) == 1
assert results[0].page_content == "Doc 1"


def test_save():
embeddings = MagicMock()
args = HNSWLibArgs(space='cosine', num_dimensions=2)
args = HNSWLibArgs(space="cosine", num_dimensions=2)
hnswlib_instance = HNSWLib(embeddings, args)
hnswlib_instance._index = MagicMock()
hnswlib_instance._index.save_index.return_value = None
hnswlib_instance.docstore = InMemoryDocstore()
hnswlib_instance.docstore._dict = {'1': Document("Doc 1"), '2': Document("Doc 2")}

with patch('builtins.open', new_callable=MagicMock()) as mock_open:
with patch('os.path.exists') as mock_exists:
with patch('os.makedirs') as mock_makedirs:
hnswlib_instance.docstore._dict = {
"1": Document("Doc 1"),
"2": Document("Doc 2"),
}

with patch("builtins.open", new_callable=MagicMock()) as mock_open:
with patch("os.path.exists") as mock_exists:
with patch("os.makedirs") as mock_makedirs:
mock_exists.return_value = False
hnswlib_instance.save('test_directory')
mock_makedirs.assert_called_with('test_directory')
hnswlib_instance._index.save_index.assert_called_with(os.path.join('test_directory', 'hnswlib.index'))
assert mock_open.call_count == 2 # For docstore.json and args.json
hnswlib_instance.save("test_directory")
mock_makedirs.assert_called_with("test_directory")
hnswlib_instance._index.save_index.assert_called_with(
os.path.join("test_directory", "hnswlib.index")
)
assert (
mock_open.call_count == 2
) # For docstore.json and args.json


def test_load():
embeddings = MagicMock()
with patch('builtins.open', new_callable=MagicMock()) as mock_open:
with patch('json.load') as mock_json_load:
with patch("builtins.open", new_callable=MagicMock()) as mock_open:
with patch("json.load") as mock_json_load:
mock_json_load.side_effect = [
{'space': 'cosine', 'num_dimensions': 2}, # For args.json
[['1', {'page_content': 'Doc 1', 'metadata': {}}], ['2', {'page_content': 'Doc 2', 'metadata': {}}]], # For docstore.json
{"space": "cosine", "num_dimensions": 2}, # For args.json
[
["1", {"page_content": "Doc 1", "metadata": {}}],
["2", {"page_content": "Doc 2", "metadata": {}}],
], # For docstore.json
]
with patch('doc_generator.utils.HNSWLib.hnswlib.Index') as mock_index_class:
with patch(
"doc_generator.utils.HNSWLib.hnswlib.Index"
) as mock_index_class:
mock_index = MagicMock()
mock_index.load_index.return_value = None
mock_index_class.return_value = mock_index
hnswlib_instance = HNSWLib.load('test_directory', embeddings)
assert hnswlib_instance.args.space == 'cosine'
hnswlib_instance = HNSWLib.load("test_directory", embeddings)
assert hnswlib_instance.args.space == "cosine"
assert hnswlib_instance.args.num_dimensions == 2
assert hnswlib_instance._index == mock_index
assert '1' in hnswlib_instance.docstore._dict
assert hnswlib_instance.docstore._dict['1'].page_content == 'Doc 1'
assert "1" in hnswlib_instance.docstore._dict
assert (
hnswlib_instance.docstore._dict["1"].page_content
== "Doc 1"
)


def test_init_index_no_index():
embeddings = MagicMock()
args = HNSWLibArgs(space='cosine', num_dimensions=None)
args = HNSWLibArgs(space="cosine", num_dimensions=None)
hnswlib_instance = HNSWLib(embeddings, args)
vectors = [[0.1, 0.2, 0.3]]
with patch('doc_generator.utils.HNSWLib.HNSWLib.get_hierarchical_nsw') as mock_get_hnsw:
with patch(
"doc_generator.utils.HNSWLib.HNSWLib.get_hierarchical_nsw"
) as mock_get_hnsw:
mock_get_hnsw.return_value = MagicMock()
hnswlib_instance.init_index(vectors)
assert hnswlib_instance.args.num_dimensions == 3
Expand All @@ -228,8 +273,9 @@ def test_init_index_no_index():

def test_init_index_with_index():
embeddings = MagicMock()
args = HNSWLibArgs(space='cosine', num_dimensions=3)
args = HNSWLibArgs(space="cosine", num_dimensions=3)
mock_index = MagicMock()
mock_index.element_count = None
args.index = mock_index
hnswlib_instance = HNSWLib(embeddings, args)
hnswlib_instance._index = mock_index
Expand All @@ -240,26 +286,26 @@ def test_init_index_with_index():

def test_save_directory_exists():
embeddings = MagicMock()
args = HNSWLibArgs(space='cosine', num_dimensions=2)
args = HNSWLibArgs(space="cosine", num_dimensions=2)
hnswlib_instance = HNSWLib(embeddings, args)
hnswlib_instance._index = MagicMock()
hnswlib_instance.docstore = MagicMock()
hnswlib_instance.docstore._dict = {}

with patch('os.path.exists') as mock_exists:
with patch('os.makedirs') as mock_makedirs:
with patch("os.path.exists") as mock_exists:
with patch("os.makedirs") as mock_makedirs:
mock_exists.return_value = True
hnswlib_instance.save('.')
hnswlib_instance.save(".")
mock_makedirs.assert_not_called()


def test_runtime_error_caught(capsys):
embeddings = MagicMock()
args = HNSWLibArgs(space='cosine', num_dimensions=2)
args = HNSWLibArgs(space="cosine", num_dimensions=2)
hnswlib_instance = HNSWLib(embeddings, args)
hnswlib_instance._index = None

with pytest.raises(AttributeError):
hnswlib_instance._index.save_index('some_path')
hnswlib_instance._index.save_index("some_path")
# The save method has no exception handling, so a RuntimeError cannot be caught here.
# With _index set to None, the call above raises AttributeError; this test only shows that such an error propagates to the caller.
Loading

0 comments on commit 0fb6926

Please sign in to comment.