Commit
nit: comment
feat: rework kmeans to be closer to FAISS

chore: store kmeans functions as class attributes

fix: method assignment

chore: more memory efficient

lint

chore: lower bsize, results unaffected

feat: better batching, lower max doc count

chore: batch size safe for 8gb GPUs

chore: more elaborate warning

chore: use external lib to support minibatching, revert to homebrew later
bclavie committed Mar 18, 2024
1 parent f6adedd commit d89656b
Showing 3 changed files with 324 additions and 110 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -23,6 +23,7 @@ colbert-ai = "0.2.19"
langchain = "^0.1.0"
onnx = "^1.15.0"
srsly = "2.4.8"
fast-pytorch-kmeans= "0.2.0.1"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.0"
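The new fast-pytorch-kmeans dependency is the external library the commit message refers to for minibatched k-means on the GPU. As a rough illustration of how that library is typically driven, assuming the KMeans constructor and its minibatch argument as documented by the library (the cluster count, batch size, and random embeddings below are illustrative, not values taken from this commit):

import torch
from fast_pytorch_kmeans import KMeans

# Illustrative stand-in for the token embeddings clustered during indexing.
device = "cuda" if torch.cuda.is_available() else "cpu"
embeddings = torch.randn(100_000, 128, device=device)

# `minibatch` caps how many points are processed per iteration, which is what
# keeps peak GPU memory bounded on smaller (e.g. 8GB) cards.
kmeans = KMeans(n_clusters=4096, mode="euclidean", max_iter=10, minibatch=65536)
labels = kmeans.fit_predict(embeddings)  # cluster assignment per vector
centroids = kmeans.centroids             # (n_clusters, dim) tensor of centers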
77 changes: 44 additions & 33 deletions ragatouille/models/index.py
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from pathlib import Path
from time import time
from typing import Any, List, Literal, Optional, TypeVar, Union
@@ -33,7 +34,8 @@ def construct(
overwrite: Union[bool, str] = "reuse",
verbose: bool = True,
**kwargs,
) -> "ModelIndex": ...
) -> "ModelIndex":
...

@staticmethod
@abstractmethod
@@ -43,7 +45,8 @@ def load_from_file(
index_config: dict[str, Any],
config: ColBERTConfig,
verbose: bool = True,
) -> "ModelIndex": ...
) -> "ModelIndex":
...

@abstractmethod
def build(
@@ -53,7 +56,8 @@ def build(
index_name: Optional["str"] = None,
overwrite: Union[bool, str] = "reuse",
verbose: bool = True,
) -> None: ...
) -> None:
...

@abstractmethod
def search(
@@ -68,13 +72,16 @@ def search(
pids: Optional[List[int]] = None,
force_reload: bool = False,
**kwargs,
) -> list[tuple[list, list, list]]: ...
) -> list[tuple[list, list, list]]:
...

@abstractmethod
def _search(self, query: str, k: int, pids: Optional[List[int]] = None): ...
def _search(self, query: str, k: int, pids: Optional[List[int]] = None):
...

@abstractmethod
def _batch_search(self, query: list[str], k: int): ...
def _batch_search(self, query: list[str], k: int):
...

@abstractmethod
def add(
@@ -87,7 +94,8 @@ def add(
new_collection: List[str],
verbose: bool = True,
**kwargs,
) -> None: ...
) -> None:
...

@abstractmethod
def delete(
@@ -98,10 +106,12 @@ def delete(
index_name: str,
pids_to_remove: Union[TypeVar("T"), List[TypeVar("T")]],
verbose: bool = True,
) -> None: ...
) -> None:
...

@abstractmethod
def _export_config(self) -> dict[str, Any]: ...
def _export_config(self) -> dict[str, Any]:
...

def export_metadata(self) -> dict[str, Any]:
config = self._export_config()
@@ -120,6 +130,8 @@ class HNSWModelIndex(ModelIndex):
class PLAIDModelIndex(ModelIndex):
_DEFAULT_INDEX_BSIZE = 32
index_type = "PLAID"
faiss_kmeans = staticmethod(deepcopy(CollectionIndexer._train_kmeans))
pytorch_kmeans = staticmethod(torch_kmeans._train_kmeans)

def __init__(self, config: ColBERTConfig) -> None:
super().__init__(config)
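The two new class attributes keep handles on both k-means implementations: deepcopy preserves the original FAISS-based _train_kmeans before any patching, and the staticmethod wrapper keeps the stored functions from being re-bound as instance methods on attribute access, which is presumably what the "fix: method assignment" message refers to. A minimal sketch of that pitfall, with purely illustrative names:

def train_kmeans(indexer, config):
    # Stand-in for a function whose first parameter is an explicit indexer object.
    return f"clustering with niters={config['kmeans_niters']}"

class Unwrapped:
    kmeans_fn = train_kmeans                # attribute access yields a *bound* method

class Wrapped:
    kmeans_fn = staticmethod(train_kmeans)  # attribute access yields the plain function

w = Wrapped()
print(w.kmeans_fn("some-indexer", {"kmeans_niters": 4}))  # behaves as expected

u = Unwrapped()
# u.kmeans_fn("some-indexer", {...}) would silently pass `u` as `indexer`
# and shift every other argument by one position.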
@@ -171,36 +183,35 @@ def build(
self.config, ColBERTConfig(nbits=nbits, index_bsize=bsize)
)

if len(collection) > 100000:
self.config.kmeans_niters = 4
elif len(collection) > 50000:
self.config.kmeans_niters = 10
else:
self.config.kmeans_niters = 20

# Instruct colbert-ai to disable forking if nranks == 1
self.config.avoid_fork_if_possible = True

# Monkey-patch colbert-ai to avoid using FAISS
monkey_patching = False
if len(collection) < 500000 and kwargs.get("use_faiss", False) is False:
monkey_patching = (
len(collection) < 40000 and kwargs.get("use_faiss", False) is False
)
if monkey_patching:
print(
"---- WARNING! You are using PLAID with an experimental replacement for FAISS for greater compatibility ----"
)
print("This is a behaviour change from RAGatouille 0.8.0 onwards.")
print(
"This works fine for most users, but is slower than FAISS and slightly more approximate."
"This works fine for most users and smallish datasets, but can be considerably slower than FAISS and could cause worse results in some situations."
)
print(
"If you're confident with FAISS working issue-free on your machine, pass use_faiss=True to revert to the FAISS-using behaviour."
"If you're confident with FAISS working on your machine, pass use_faiss=True to revert to the FAISS-using behaviour."
)
print("--------------------")
if not hasattr(CollectionIndexer, "_original_train_kmeans"):
CollectionIndexer._original_train_kmeans = (
CollectionIndexer._train_kmeans
)
CollectionIndexer._train_kmeans = torch_kmeans._train_kmeans
monkey_patching = True
CollectionIndexer._train_kmeans = self.pytorch_kmeans

# Try to keep runtime stable -- these are values that empirically didn't degrade performance at all on 3 benchmarks.
# More tests required before warning can be removed.
if len(collection) > 20000:
self.config.means_niters = 4
elif len(collection) > 10000:
self.config.kmeans_niters = 8
else:
self.config.kmeans_niters = 10
try:
indexer = Indexer(
checkpoint=checkpoint,
@@ -216,15 +227,15 @@
f"PyTorch-based indexing did not succeed with error: {err}",
"! Reverting to using FAISS and attempting again...",
)
CollectionIndexer._train_kmeans = (
CollectionIndexer._original_train_kmeans
)
monkey_patching = False
if monkey_patching is False:
if hasattr(CollectionIndexer, "_original_train_kmeans"):
CollectionIndexer._train_kmeans = (
CollectionIndexer._original_train_kmeans
)
if len(collection) > 100000:
self.config.kmeans_niters = 4
elif len(collection) > 50000:
self.config.kmeans_niters = 10
else:
self.config.kmeans_niters = 20
CollectionIndexer._train_kmeans = self.faiss_kmeans
if torch.cuda.is_available():
import faiss

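Putting the pieces of the build() diff together: for collections under the new 40,000-document threshold (and when use_faiss is not set), CollectionIndexer._train_kmeans is swapped for the PyTorch implementation with a size-dependent number of k-means iterations, and the FAISS version stored on the class is restored if that path fails. A condensed sketch of that patch-and-restore flow; run_indexing is a hypothetical placeholder for the call into colbert-ai's Indexer, and the import paths are assumptions based on the repository layout:

from colbert.indexing.collection_indexer import CollectionIndexer  # assumed colbert-ai path
from ragatouille.models.index import PLAIDModelIndex

def run_indexing(config, collection):
    # Hypothetical placeholder for Indexer(checkpoint=...).index(...) from colbert-ai.
    ...

def pick_kmeans_niters(n_docs: int) -> int:
    # Mirrors the intent of the size-based schedule in build(): fewer iterations
    # on larger collections to keep indexing runtime stable.
    if n_docs > 20_000:
        return 4
    if n_docs > 10_000:
        return 8
    return 10

def build_with_fallback(config, collection, use_faiss: bool = False):
    use_pytorch = len(collection) < 40_000 and not use_faiss
    if use_pytorch:
        CollectionIndexer._train_kmeans = PLAIDModelIndex.pytorch_kmeans
        config.kmeans_niters = pick_kmeans_niters(len(collection))
        try:
            return run_indexing(config, collection)
        except Exception as err:
            print(f"PyTorch-based indexing failed ({err}); retrying with FAISS.")
    # Fall back to the FAISS-based k-means captured at class definition time.
    CollectionIndexer._train_kmeans = PLAIDModelIndex.faiss_kmeans
    return run_indexing(config, collection)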