Skip to content

Commit d50d75c

Browse files
committed
add dedicated ModelIndex class
1 parent 2e5f77e commit d50d75c

File tree

2 files changed

+260
-55
lines changed

2 files changed

+260
-55
lines changed

ragatouille/models/colbert.py

Lines changed: 47 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from colbert.modeling.checkpoint import Checkpoint
1414

1515
from ragatouille.models.base import LateInteractionModel
16+
from ragatouille.models.index import ModelIndex, ModelIndexFactory
1617

1718
# TODO: Move all bsize related calcs to `_set_bsize()`
1819

@@ -40,6 +41,7 @@ def __init__(
4041

4142
self.loaded_from_index = load_from_index
4243

44+
self.model_index: Optional[ModelIndex] = None
4345
if load_from_index:
4446
self.index_path = str(pretrained_model_name_or_path)
4547
ckpt_config = ColBERTConfig.load_from_index(
@@ -299,6 +301,20 @@ def delete_from_index(
299301

300302
print(f"Successfully deleted documents with these IDs: {document_ids}")
301303

304+
def _save_index_metadata(self):
305+
self._write_collection_to_file(
306+
self.collection, self.index_path + "/collection.json"
307+
)
308+
309+
self._write_collection_to_file(
310+
self.pid_docid_map, self.index_path + "/pid_docid_map.json"
311+
)
312+
313+
if self.docid_metadata_map is not None:
314+
self._write_collection_to_file(
315+
self.docid_metadata_map, self.index_path + "/docid_metadata_map.json"
316+
)
317+
302318
def index(
303319
self,
304320
collection: List[str],
@@ -309,21 +325,9 @@ def index(
309325
overwrite: Union[bool, str] = "reuse",
310326
bsize: int = 32,
311327
):
312-
if torch.cuda.is_available():
313-
import faiss
314-
315-
if not hasattr(faiss, "StandardGpuResources"):
316-
print(
317-
"________________________________________________________________________________\n"
318-
"WARNING! You have a GPU available, but only `faiss-cpu` is currently installed.\n",
319-
"This means that indexing will be slow. To make use of your GPU.\n"
320-
"Please install `faiss-gpu` by running:\n"
321-
"pip uninstall --y faiss-cpu & pip install faiss-gpu\n",
322-
"________________________________________________________________________________",
323-
)
324-
print("Will continue with CPU indexing in 5 seconds...")
325-
time.sleep(5)
328+
self.collection = collection
326329
self.config.doc_maxlen = max_document_length
330+
327331
if index_name is not None:
328332
if self.index_name is not None:
329333
print(
@@ -339,36 +343,6 @@ def index(
339343
)
340344
self.index_name = self.checkpoint + "new_index"
341345

342-
self.collection = collection
343-
344-
nbits = 2
345-
if len(self.collection) < 5000:
346-
nbits = 8
347-
elif len(self.collection) < 10000:
348-
nbits = 4
349-
self.config = ColBERTConfig.from_existing(
350-
self.config, ColBERTConfig(nbits=nbits, index_bsize=bsize)
351-
)
352-
353-
if len(self.collection) > 100000:
354-
self.config.kmeans_niters = 4
355-
elif len(self.collection) > 50000:
356-
self.config.kmeans_niters = 10
357-
else:
358-
self.config.kmeans_niters = 20
359-
360-
# Instruct colbert-ai to disable forking if nranks == 1
361-
self.config.avoid_fork_if_possible = True
362-
self.indexer = Indexer(
363-
checkpoint=self.checkpoint,
364-
config=self.config,
365-
verbose=self.verbose,
366-
)
367-
self.indexer.configure(avoid_fork_if_possible=True)
368-
self.indexer.index(
369-
name=self.index_name, collection=self.collection, overwrite=overwrite
370-
)
371-
372346
self.index_path = str(
373347
Path(self.run_config.root)
374348
/ Path(self.run_config.experiment)
@@ -378,25 +352,43 @@ def index(
378352
self.config.root = str(
379353
Path(self.run_config.root) / Path(self.run_config.experiment) / "indexes"
380354
)
381-
self._write_collection_to_file(
382-
self.collection, self.index_path + "/collection.json"
383-
)
384355

385356
self.pid_docid_map = pid_docid_map
386-
self._write_collection_to_file(
387-
self.pid_docid_map, self.index_path + "/pid_docid_map.json"
388-
)
389357

390358
# inverted mapping for returning full docs
391359
self.docid_pid_map = defaultdict(list)
392360
for pid, docid in self.pid_docid_map.items():
393361
self.docid_pid_map[docid].append(pid)
394362

395-
if docid_metadata_map is not None:
396-
self._write_collection_to_file(
397-
docid_metadata_map, self.index_path + "/docid_metadata_map.json"
398-
)
399-
self.docid_metadata_map = docid_metadata_map
363+
self.docid_metadata_map = docid_metadata_map
364+
365+
if torch.cuda.is_available():
366+
import faiss
367+
368+
if not hasattr(faiss, "StandardGpuResources"):
369+
print(
370+
"________________________________________________________________________________\n"
371+
"WARNING! You have a GPU available, but only `faiss-cpu` is currently installed.\n",
372+
"This means that indexing will be slow. To make use of your GPU.\n"
373+
"Please install `faiss-gpu` by running:\n"
374+
"pip uninstall --y faiss-cpu & pip install faiss-gpu\n",
375+
"________________________________________________________________________________",
376+
)
377+
print("Will continue with CPU indexing in 5 seconds...")
378+
time.sleep(5)
379+
380+
self.model_index = ModelIndexFactory.construct(
381+
"PLAID",
382+
self.config,
383+
self.checkpoint,
384+
self.collection,
385+
self.index_name,
386+
overwrite,
387+
self.verbose,
388+
bsize=bsize,
389+
)
390+
self.config = self.model_index.config
391+
self._save_index_metadata()
400392

401393
print("Done indexing!")
402394

ragatouille/models/index.py

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
from abc import ABC, abstractmethod
2+
from pathlib import Path
3+
from time import time
4+
from typing import Any, List, Literal, Optional, TypeAlias, Union
5+
6+
from colbert import Indexer
7+
from colbert.infra import ColBERTConfig
8+
9+
import torch
10+
11+
import srsly
12+
13+
14+
IndexType: TypeAlias = Literal["FLAT", "HNSW", "PLAID"]
15+
16+
17+
class ModelIndex(ABC):
18+
index_type: IndexType
19+
20+
def __init__(
21+
self,
22+
config: ColBERTConfig,
23+
) -> None:
24+
self.config = config
25+
26+
@staticmethod
27+
@abstractmethod
28+
def construct(
29+
config: ColBERTConfig,
30+
checkpoint: str,
31+
collection: List[str],
32+
index_name: Optional["str"] = None,
33+
overwrite: Union[bool, str] = "reuse",
34+
verbose: bool = True,
35+
**kwargs,
36+
) -> "ModelIndex":
37+
...
38+
39+
@staticmethod
40+
@abstractmethod
41+
def load_from_file(pretrained_model_path: Path) -> "ModelIndex":
42+
...
43+
44+
@abstractmethod
45+
def build(self) -> None:
46+
...
47+
48+
@abstractmethod
49+
def search(self) -> None:
50+
...
51+
52+
@abstractmethod
53+
def batch_search(self) -> None:
54+
...
55+
56+
@abstractmethod
57+
def add(self) -> None:
58+
...
59+
60+
@abstractmethod
61+
def delete(self) -> None:
62+
...
63+
64+
@abstractmethod
65+
def export(self) -> Optional[dict[str, Any]]:
66+
...
67+
68+
69+
class FLATModelIndex(ModelIndex):
70+
index_type = "FLAT"
71+
72+
73+
class HNSWModelIndex(ModelIndex):
74+
index_type = "HNSW"
75+
76+
77+
class PLAIDModelIndex(ModelIndex):
78+
index_type = "PLAID"
79+
80+
def __init__(self, config: ColBERTConfig) -> None:
81+
super().__init__(config)
82+
83+
@staticmethod
84+
def construct(
85+
config: ColBERTConfig,
86+
checkpoint: Union[str, Path],
87+
collection: List[str],
88+
index_name: Optional["str"] = None,
89+
overwrite: Union[bool, str] = "reuse",
90+
verbose: bool = True,
91+
**kwargs,
92+
) -> "PLAIDModelIndex":
93+
bsize = kwargs.get("bsize", 32)
94+
assert isinstance(bsize, int)
95+
96+
if torch.cuda.is_available():
97+
import faiss
98+
99+
if not hasattr(faiss, "StandardGpuResources"):
100+
print(
101+
"________________________________________________________________________________\n"
102+
"WARNING! You have a GPU available, but only `faiss-cpu` is currently installed.\n",
103+
"This means that indexing will be slow. To make use of your GPU.\n"
104+
"Please install `faiss-gpu` by running:\n"
105+
"pip uninstall --y faiss-cpu & pip install faiss-gpu\n",
106+
"________________________________________________________________________________",
107+
)
108+
print("Will continue with CPU indexing in 5 seconds...")
109+
time.sleep(5)
110+
111+
nbits = 2
112+
if len(collection) < 5000:
113+
nbits = 8
114+
elif len(collection) < 10000:
115+
nbits = 4
116+
config = ColBERTConfig.from_existing(
117+
config, ColBERTConfig(nbits=nbits, index_bsize=bsize)
118+
)
119+
120+
if len(collection) > 100000:
121+
config.kmeans_niters = 4
122+
elif len(collection) > 50000:
123+
config.kmeans_niters = 10
124+
else:
125+
config.kmeans_niters = 20
126+
127+
# Instruct colbert-ai to disable forking if nranks == 1
128+
config.avoid_fork_if_possible = True
129+
indexer = Indexer(
130+
checkpoint=checkpoint,
131+
config=config,
132+
verbose=verbose,
133+
)
134+
indexer.configure(avoid_fork_if_possible=True)
135+
indexer.index(name=index_name, collection=collection, overwrite=overwrite)
136+
return PLAIDModelIndex(config)
137+
138+
@staticmethod
139+
def load_from_file(pretrained_model_path: Path) -> "PLAIDModelIndex":
140+
raise NotImplementedError()
141+
142+
def build(self) -> None:
143+
raise NotImplementedError()
144+
145+
def search(self) -> None:
146+
raise NotImplementedError()
147+
148+
def batch_search(self) -> None:
149+
raise NotImplementedError()
150+
151+
def add(self) -> None:
152+
raise NotImplementedError()
153+
154+
def delete(self) -> None:
155+
raise NotImplementedError()
156+
157+
def export(self) -> Optional[dict[str, Any]]:
158+
raise NotImplementedError()
159+
160+
161+
class ModelIndexFactory:
162+
_MODEL_INDEX_BY_NAME = {
163+
"FLAT": FLATModelIndex,
164+
"HNSW": HNSWModelIndex,
165+
"PLAID": PLAIDModelIndex,
166+
}
167+
168+
@staticmethod
169+
def _raise_if_invalid_index_type(index_type: str) -> IndexType:
170+
if index_type not in ["FLAT", "HNSW", "PLAID"]:
171+
raise ValueError(
172+
f"Unsupported index_type `{index_type}`; it must be one of 'FLAT', 'HNSW', OR 'PLAID'"
173+
)
174+
return index_type # type: ignore
175+
176+
@staticmethod
177+
def construct(
178+
index_type: Union[Literal["auto"], IndexType],
179+
config: ColBERTConfig,
180+
checkpoint: str,
181+
collection: List[str],
182+
index_name: Optional["str"] = None,
183+
overwrite: Union[bool, str] = "reuse",
184+
verbose: bool = True,
185+
**kwargs,
186+
) -> ModelIndex:
187+
# Automatically choose the appropriate index for the desired "workload".
188+
if index_type == "auto":
189+
# NOTE: For now only PLAID indexes are supported.
190+
index_type = "PLAID"
191+
return ModelIndexFactory._MODEL_INDEX_BY_NAME[
192+
ModelIndexFactory._raise_if_invalid_index_type(index_type)
193+
].construct(
194+
config, checkpoint, collection, index_name, overwrite, verbose, **kwargs
195+
)
196+
197+
@staticmethod
198+
def _file_index_type(pretrained_model_path: Path) -> IndexType:
199+
try:
200+
index_type = srsly.read_json(str(pretrained_model_path / "metadata.json"))[
201+
"index_type"
202+
]
203+
assert isinstance(index_type, str)
204+
except KeyError:
205+
index_type = "PLAID"
206+
return ModelIndexFactory._raise_if_invalid_index_type(index_type)
207+
208+
@staticmethod
209+
def load_from_file(pretrained_model_path: Path) -> ModelIndex:
210+
index_type = ModelIndexFactory._file_index_type(pretrained_model_path)
211+
return ModelIndexFactory._MODEL_INDEX_BY_NAME[index_type].load_from_file(
212+
pretrained_model_path
213+
)

0 commit comments

Comments
 (0)