Fix a bug with excess RAM usage during vector computes. (#1195)

nsthorat authored Feb 29, 2024
1 parent f74d88d commit 5d84c76

Showing 6 changed files with 51 additions and 30 deletions.
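
The change in a nutshell: VectorStore.get() used to materialize every requested embedding into a single (n, dim) ndarray, so a query touching many spans allocated the whole matrix at once. It now returns an Iterator[np.ndarray], yielding one vector at a time and fetching from the backing index in fixed-size batches. A minimal sketch of the pattern (the stream_embeddings name and fetch_batch callable are illustrative, not Lilac's actual internals):

    from typing import Callable, Iterator, Sequence

    import numpy as np

    BATCH_SIZE = 1024  # mirrors HNSW_RETRIEVAL_BATCH_SIZE in the hnsw store below

    def stream_embeddings(
      locs: Sequence[int], fetch_batch: Callable[[Sequence[int]], list[list[float]]]
    ) -> Iterator[np.ndarray]:
      """Yield embeddings one at a time, fetching at most BATCH_SIZE rows per call."""
      for start in range(0, len(locs), BATCH_SIZE):
        batch = np.asarray(fetch_batch(locs[start:start + BATCH_SIZE]), dtype=np.float32)
        for row in batch:  # only one batch is ever resident in memory
          yield row
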
4 changes: 3 additions & 1 deletion lilac/data/dataset_select_rows_sort_test.py
@@ -473,7 +473,9 @@ def vector_compute(
     self, all_vector_spans: Iterable[list[SpanVector]]
   ) -> Iterator[Optional[Item]]:
     for vector_spans in all_vector_spans:
-      embeddings = np.array([vector_span['vector'] for vector_span in vector_spans])
+      embeddings = np.array([vector_span['vector'] for vector_span in vector_spans]).reshape(
+        len(vector_spans), -1
+      )
       scores = embeddings.dot(self._query).reshape(-1)
       res: Item = []
       for vector_span, score in zip(vector_spans, scores):
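
The test fixture change above guards a shape edge case: the stores now yield each vector through np.squeeze, and when an embedding has a single dimension, squeeze collapses its (1, 1) row to a 0-d scalar, so np.array over the collected vectors comes out as shape (n,) rather than (n, dim). The added .reshape(len(vector_spans), -1) makes the 2-D shape explicit regardless of how the per-vector arrays arrive. A toy reproduction of the failure mode (illustrative values):

    import numpy as np

    # Two 1-dim embeddings, squeezed the way the stores now yield them.
    vecs = [np.squeeze(np.array([[1.0]])), np.squeeze(np.array([[2.0]]))]
    flat = np.array(vecs)                # shape (2,): 0-d scalars stack into 1-D
    fixed = flat.reshape(len(vecs), -1)  # shape (2, 1): 2-D again
    scores = fixed.dot(np.array([3.0]))  # matrix-vector product -> array([3., 6.])
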
9 changes: 3 additions & 6 deletions lilac/embeddings/vector_store.py
@@ -3,7 +3,7 @@
 import abc
 import os
 import pickle
-from typing import Iterable, Optional, Sequence, Type, cast
+from typing import Iterable, Iterator, Optional, Sequence, Type, cast
 
 import numpy as np
 
@@ -50,7 +50,7 @@ def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None:
     pass
 
   @abc.abstractmethod
-  def get(self, keys: Optional[Iterable[VectorKey]] = None) -> np.ndarray:
+  def get(self, keys: Optional[Iterable[VectorKey]] = None) -> Iterator[np.ndarray]:
     """Return the embeddings for given keys.
 
     Args:
@@ -159,13 +159,10 @@ def get(self, keys: Iterable[PathKey]) -> Iterable[list[SpanVector]]:
       all_spans.append(spans)
       all_vector_keys.append([(*path_key, i) for i in range(len(spans))])
 
-    offset = 0
     flat_vector_keys = [key for vector_keys in all_vector_keys for key in (vector_keys or [])]
     all_vectors = self._vector_store.get(flat_vector_keys)
     for spans in all_spans:
-      vectors = all_vectors[offset : offset + len(spans)]
-      yield [{'span': span, 'vector': vector} for span, vector in zip(spans, vectors)]
-      offset += len(spans)
+      yield [{'span': span, 'vector': next(all_vectors)} for span in spans]
 
   def topk(
     self, query: np.ndarray, k: int, rowids: Optional[Iterable[str]] = None
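
Since get() is now lazy, the offset bookkeeping above disappears: each document simply draws as many vectors off the shared iterator as it has spans, and alignment falls out of consumption order. A minimal model of that pattern (illustrative, not the full lilac code):

    from typing import Iterator

    import numpy as np

    def pair_spans_with_vectors(
      all_spans: list[list[dict]], vectors: Iterator[np.ndarray]
    ) -> Iterator[list[dict]]:
      # next() advances the shared iterator exactly len(spans) times per
      # document, so spans and vectors stay aligned with no offsets or slices.
      for spans in all_spans:
        yield [{'span': span, 'vector': next(vectors)} for span in spans]
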
19 changes: 13 additions & 6 deletions lilac/embeddings/vector_store_hnsw.py
@@ -3,15 +3,15 @@
 import multiprocessing
 import os
 import threading
-from typing import Iterable, Optional, Set, cast
+from typing import Iterable, Iterator, Optional, Set, cast
 
 import hnswlib
 import numpy as np
 import pandas as pd
 from typing_extensions import override
 
 from ..schema import VectorKey
-from ..utils import DebugTimer
+from ..utils import DebugTimer, chunks
 from .vector_store import VectorStore
 
 _HNSW_SUFFIX = '.hnswlib.bin'
@@ -22,6 +22,8 @@
 CONSTRUCTION_EF = 100
 M = 16
 SPACE = 'ip'
+# The number of items to retrieve at a time given a query of keys.
+HNSW_RETRIEVAL_BATCH_SIZE = 1024
 
 
 class HNSWVectorStore(VectorStore):
@@ -105,15 +107,20 @@ def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None:
     self._index.set_ef(min(QUERY_EF, self.size()))
 
   @override
-  def get(self, keys: Optional[Iterable[VectorKey]] = None) -> np.ndarray:
+  def get(self, keys: Optional[Iterable[VectorKey]] = None) -> Iterator[np.ndarray]:
     assert (
       self._index is not None and self._key_to_label is not None
     ), 'No embeddings exist in this store.'
     with self._lock:
       if not keys:
-        return np.array(self._index.get_items(self._key_to_label.values), dtype=np.float32)
-      locs = self._key_to_label.loc[cast(list[str], keys)].values
-      return np.array(self._index.get_items(locs), dtype=np.float32)
+        locs = self._key_to_label.values
+      else:
+        locs = self._key_to_label.loc[cast(list[str], keys)].values
 
+    for loc_chunk in chunks(locs, HNSW_RETRIEVAL_BATCH_SIZE):
+      chunk_items = np.array(self._index.get_items(loc_chunk), dtype=np.float32)
+      for vector in np.split(chunk_items, chunk_items.shape[0]):
+        yield np.squeeze(vector)
 
   @override
   def topk(
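
This is where the RAM saving actually lands: hnswlib's get_items copies vectors out of the index, so the old code allocated the full (n, dim) float32 matrix for every query. Batching at HNSW_RETRIEVAL_BATCH_SIZE bounds each copy to 1024 rows. The chunks helper is imported from lilac's utils module; its exact implementation isn't shown in this diff, but a compatible version looks like this (a sketch, not the actual lilac/utils.py code):

    from itertools import islice
    from typing import Iterable, Iterator, TypeVar

    T = TypeVar('T')

    def chunks(items: Iterable[T], size: int) -> Iterator[list[T]]:
      """Yield successive lists of at most `size` items from `items`."""
      it = iter(items)
      while True:
        chunk = list(islice(it, size))
        if not chunk:
          return
        yield chunk
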
14 changes: 9 additions & 5 deletions lilac/embeddings/vector_store_numpy.py
@@ -1,7 +1,7 @@
 """NumpyVectorStore class for storing vectors in numpy arrays."""
 
 import os
-from typing import Iterable, Optional, cast
+from typing import Iterable, Iterator, Optional, cast
 
 import numpy as np
 import pandas as pd
@@ -73,14 +73,18 @@ def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None:
     self._key_to_index = new_key_to_label
 
   @override
-  def get(self, keys: Optional[Iterable[VectorKey]] = None) -> np.ndarray:
+  def get(self, keys: Optional[Iterable[VectorKey]] = None) -> Iterator[np.ndarray]:
     assert (
       self._embeddings is not None and self._key_to_index is not None
     ), 'The vector store has no embeddings. Call load() or add() first.'
     if not keys:
-      return self._embeddings
-    locs = self._key_to_index.loc[cast(list[str], keys)]
-    return self._embeddings.take(locs, axis=0)
+      embeddings = self._embeddings
+    else:
+      locs = self._key_to_index.loc[cast(list[str], keys)]
+      embeddings = self._embeddings.take(locs, axis=0)
 
+    for vector in np.split(embeddings, embeddings.shape[0]):
+      yield np.squeeze(vector)
 
   @override
   def topk(
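
For the in-memory numpy store the new iterator costs essentially nothing extra: np.split returns views into the backing array and np.squeeze preserves them as views, so no embedding data is copied. A quick shape check on toy data:

    import numpy as np

    embeddings = np.array([[1, 2], [3, 4], [5, 6]])
    rows = [np.squeeze(v) for v in np.split(embeddings, embeddings.shape[0])]
    assert [r.shape for r in rows] == [(2,), (2,), (2,)]
    assert all(r.base is not None for r in rows)  # views into `embeddings`, not copies
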
29 changes: 19 additions & 10 deletions lilac/embeddings/vector_store_test.py
@@ -22,25 +22,32 @@ def test_add_chunks(self, store_cls: Type[VectorStore]) -> None:
     store.add([('a',), ('b',)], np.array([[1, 2], [3, 4]]))
     store.add([('c',)], np.array([[5, 6]]))
 
-    np.testing.assert_array_equal(
-      store.get([('a',), ('b',), ('c',)]), np.array([[1, 2], [3, 4], [5, 6]])
-    )
+    vectors = list(store.get([('a',), ('b',), ('c',)]))
+    assert len(vectors) == 3
+    np.testing.assert_array_equal(vectors[0], [1, 2])
+    np.testing.assert_array_equal(vectors[1], [3, 4])
+    np.testing.assert_array_equal(vectors[2], [5, 6])
 
   def test_get_all(self, store_cls: Type[VectorStore]) -> None:
     store = store_cls()
 
     store.add([('a',), ('b',), ('c',)], np.array([[1, 2], [3, 4], [5, 6]]))
 
-    np.testing.assert_array_equal(
-      store.get([('a',), ('b',), ('c',)]), np.array([[1, 2], [3, 4], [5, 6]])
-    )
+    vectors = list(store.get([('a',), ('b',), ('c',)]))
+    assert len(vectors) == 3
+    np.testing.assert_array_equal(vectors[0], [1, 2])
+    np.testing.assert_array_equal(vectors[1], [3, 4])
+    np.testing.assert_array_equal(vectors[2], [5, 6])
 
   def test_get_subset(self, store_cls: Type[VectorStore]) -> None:
     store = store_cls()
 
     store.add([('a',), ('b',), ('c',)], np.array([[1, 2], [3, 4], [5, 6]]))
 
-    np.testing.assert_array_equal(store.get([('b',), ('c',)]), np.array([[3, 4], [5, 6]]))
+    vectors = list(store.get([('b',), ('c',)]))
+    assert len(vectors) == 2
+    np.testing.assert_array_equal(vectors[0], [3, 4])
+    np.testing.assert_array_equal(vectors[1], [5, 6])
 
   def test_save_load(self, store_cls: Type[VectorStore], tmp_path: pathlib.Path) -> None:
     store = store_cls()
@@ -54,9 +61,11 @@ def test_save_load(self, store_cls: Type[VectorStore], tmp_path: pathlib.Path) -> None:
     store = store_cls()
     store.load((str(tmp_path)))
 
-    np.testing.assert_array_equal(
-      store.get([('a',), ('b',), ('c',)]), np.array([[1, 2], [3, 4], [5, 6]])
-    )
+    vectors = list(store.get([('a',), ('b',), ('c',)]))
+    assert len(vectors) == 3
+    np.testing.assert_array_equal(vectors[0], [1, 2])
+    np.testing.assert_array_equal(vectors[1], [3, 4])
+    np.testing.assert_array_equal(vectors[2], [5, 6])
 
   def test_topk(self, store_cls: Type[VectorStore]) -> None:
     store = store_cls()
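
The rewritten assertions reflect a behavioral consequence of the new API: get() now returns a one-shot generator, so callers that need random access or reuse must materialize it first, as the tests do with list(...). For example (illustrative, assuming the import path from the file header above):

    import numpy as np

    from lilac.embeddings.vector_store_numpy import NumpyVectorStore

    store = NumpyVectorStore()
    store.add([('a',), ('b',)], np.array([[1, 2], [3, 4]]))
    it = store.get([('a',), ('b',)])
    first_pass = list(it)   # drains the generator
    second_pass = list(it)  # already exhausted
    assert len(first_pass) == 2 and second_pass == []
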
6 changes: 4 additions & 2 deletions lilac/signals/semantic_similarity_test.py
@@ -49,9 +49,11 @@ def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None:
     pass
 
   @override
-  def get(self, keys: Optional[Iterable[VectorKey]] = None) -> np.ndarray:
+  def get(self, keys: Optional[Iterable[VectorKey]] = None) -> Iterator[np.ndarray]:
     keys = keys or []
-    return np.array([EMBEDDINGS[tuple(path_key)][cast(int, index)] for *path_key, index in keys])
+    yield from [
+      np.array(EMBEDDINGS[tuple(path_key)][cast(int, index)]) for *path_key, index in keys
+    ]
 
   @override
   def delete(self, base_path: str) -> None:
