Skip to content

Commit

Permalink
support baseline import/export of indexes
Browse files Browse the repository at this point in the history
  • Loading branch information
jlscheerer committed Feb 25, 2024
1 parent d50d75c commit 1f96ecc
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 22 deletions.
21 changes: 18 additions & 3 deletions ragatouille/models/colbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
import srsly
import torch
from colbert import Indexer, IndexUpdater, Searcher, Trainer
from colbert import IndexUpdater, Searcher, Trainer
from colbert.infra import ColBERTConfig, Run, RunConfig
from colbert.modeling.checkpoint import Checkpoint

Expand All @@ -34,7 +34,6 @@ def __init__(
self.collection = None
self.pid_docid_map = None
self.docid_metadata_map = None
self.in_memory_docs = []
self.base_model_max_tokens = 512
if n_gpu == -1:
n_gpu = 1 if torch.cuda.device_count() == 0 else torch.cuda.device_count()
Expand All @@ -47,7 +46,11 @@ def __init__(
ckpt_config = ColBERTConfig.load_from_index(
str(pretrained_model_name_or_path)
)
self.config = ckpt_config
# Use pretrained_model_name_or_path, and set the config for this.
self.model_index = ModelIndexFactory.load_from_file(
self.index_path, index_name, ckpt_config
)
self.config = self.model_index.config
self.run_config = RunConfig(
nranks=n_gpu, experiment=self.config.experiment, root=self.config.root
)
Expand Down Expand Up @@ -302,6 +305,18 @@ def delete_from_index(
print(f"Successfully deleted documents with these IDs: {document_ids}")

def _save_index_metadata(self):
assert self.model_index is not None

model_metadata = srsly.read_json(self.index_path + "/metadata.json")
index_config = self.model_index.export_metadata()
index_config["index_name"] = self.index_name
# Ensure that the additional metadata we store does not collide with anything else.
model_metadata["RAGatouille"] = {"index_config": index_config} # type: ignore
self._write_collection_to_file(
model_metadata,
self.index_path + "/metadata.json",
)

self._write_collection_to_file(
self.collection, self.index_path + "/collection.json"
)
Expand Down
63 changes: 44 additions & 19 deletions ragatouille/models/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,13 @@ def construct(

@staticmethod
@abstractmethod
def load_from_file(pretrained_model_path: Path) -> "ModelIndex":
def load_from_file(
index_path: str,
index_name: Optional[str],
index_config: dict[str, Any],
config: ColBERTConfig,
verbose: bool = True,
) -> "ModelIndex":
...

@abstractmethod
Expand All @@ -62,9 +68,14 @@ def delete(self) -> None:
...

@abstractmethod
def export(self) -> Optional[dict[str, Any]]:
def _export_config(self) -> dict[str, Any]:
...

def export_metadata(self) -> dict[str, Any]:
config = self._export_config()
config["index_type"] = self.index_type
return config


class FLATModelIndex(ModelIndex):
index_type = "FLAT"
Expand Down Expand Up @@ -136,8 +147,14 @@ def construct(
return PLAIDModelIndex(config)

@staticmethod
def load_from_file(pretrained_model_path: Path) -> "PLAIDModelIndex":
raise NotImplementedError()
def load_from_file(
index_path: str,
index_name: Optional[str],
index_config: dict[str, Any],
config: ColBERTConfig,
verbose: bool = True,
) -> "PLAIDModelIndex":
return PLAIDModelIndex(config)

def build(self) -> None:
raise NotImplementedError()
Expand All @@ -154,8 +171,8 @@ def add(self) -> None:
def delete(self) -> None:
raise NotImplementedError()

def export(self) -> Optional[dict[str, Any]]:
raise NotImplementedError()
def _export_config(self) -> dict[str, Any]:
return {}


class ModelIndexFactory:
Expand Down Expand Up @@ -195,19 +212,27 @@ def construct(
)

@staticmethod
def _file_index_type(pretrained_model_path: Path) -> IndexType:
def load_from_file(
index_path: str,
index_name: Optional[str],
config: ColBERTConfig,
verbose: bool = True,
) -> ModelIndex:
metadata = srsly.read_json(index_path + "/metadata.json")
try:
index_type = srsly.read_json(str(pretrained_model_path / "metadata.json"))[
"index_type"
]
assert isinstance(index_type, str)
index_config = metadata["RAGatouille"]["index_config"] # type: ignore
except KeyError:
index_type = "PLAID"
return ModelIndexFactory._raise_if_invalid_index_type(index_type)

@staticmethod
def load_from_file(pretrained_model_path: Path) -> ModelIndex:
index_type = ModelIndexFactory._file_index_type(pretrained_model_path)
return ModelIndexFactory._MODEL_INDEX_BY_NAME[index_type].load_from_file(
pretrained_model_path
if verbose:
print(
f"Constructing default index configuration for index `{index_name}` as it does not contain RAGatouille specific metadata."
)
index_config = {
"index_type": "PLAID",
"index_name": index_name,
}
index_name = (
index_name if index_name is not None else index_config["index_name"] # type: ignore
)
return ModelIndexFactory._MODEL_INDEX_BY_NAME[
ModelIndexFactory._raise_if_invalid_index_type(index_config["index_type"]) # type: ignore
].load_from_file(index_path, index_name, index_config, config, verbose)

0 comments on commit 1f96ecc

Please sign in to comment.