Fix add_to_index overwrites and duplicate paths #157

Merged: 23 commits into main from fix_add_to_index (Mar 5, 2024)
Changes from all commits

Commits (23)
f444700  fix: avoid overwriting the old index with just the new docs (adharm, Feb 11, 2024)
4aa0952  chore: ruff linting (adharm, Feb 11, 2024)
f0b5a7f  Merge branch 'main' into fix_add_to_index (adharm, Feb 11, 2024)
3d0e675  test: added assertion for checking old docs when adding to index (adharm, Feb 11, 2024)
e9a4385  chore: added tsv and jsonl to gitignore (adharm, Feb 11, 2024)
c06b14b  test: use correct var in add_to_index assertions (adharm, Feb 11, 2024)
d3f5d32  fix: correctly update pimd_docid_map (adharm, Feb 12, 2024)
3ba34e6  chore: add .devcontainers dir to gitignore (adharm, Feb 12, 2024)
8238f83  chore: removed commented out test for returning full documents (adharm, Feb 12, 2024)
017e17e  Merge branch 'main' into fix_add_to_index (adharm, Feb 23, 2024)
6c53400  fix: corrected index_bsize to bsize when being passed to ColBERTConfig (adharm, Feb 23, 2024)
f68cd8d  Revert "fix: corrected index_bsize to bsize when being passed to ColB… (adharm, Feb 23, 2024)
5944cc5  fix: duplicate path issue in add_to_index (adharm, Feb 24, 2024)
cceb0ac  test: cleaned up CRUD tests to be more verbose on failure + added log… (adharm, Feb 24, 2024)
514966e  style: ruff formatting (adharm, Feb 24, 2024)
0e1250d  test: updated add_to_index tests (adharm, Feb 27, 2024)
e09d725  style: ruff formatting (adharm, Feb 27, 2024)
2fd6795  fix: :bug: fixed index not being written to disk if IndexUpdater is i… (adharm, Feb 27, 2024)
faa9eb0  Merge branch 'main' into fix_add_to_index (adharm, Feb 27, 2024)
7d3ff09  style: ruff format (adharm, Feb 28, 2024)
348ea6c  fix: fixed path issues by reverting previous deletion (adharm, Feb 28, 2024)
00ebc51  test: removed add to empty index test (adharm, Feb 28, 2024)
0b15265  Merge branch 'main' into fix_add_to_index (adharm, Feb 28, 2024)
7 changes: 6 additions & 1 deletion .gitignore
@@ -115,6 +115,9 @@ dmypy.json
# Cython debug symbols
cython_debug/

# data files
*.tsv
*.jsonl

.mypy.ipynb_checkpoints
.mkdocs.yml
@@ -126,4 +129,6 @@ archive/

local/

.vscode/
.vscode/

.devcontainer/
4 changes: 3 additions & 1 deletion ragatouille/RAGTrainer.py
@@ -157,7 +157,9 @@ def prepare_training_data(
)
if len(self.data_processor.training_triplets) == 0:
if mine_hard_negatives:
print("Warning: No training triplets were generated with setting mine_hard_negatives=='True'. This may be due to the data being too small or the hard negative miner not being able to find enough hard negatives.")
print(
"Warning: No training triplets were generated with setting mine_hard_negatives=='True'. This may be due to the data being too small or the hard negative miner not being able to find enough hard negatives."
)
self.data_processor.process_raw_data(
data_type=data_type,
raw_data=raw_data,
4 changes: 3 additions & 1 deletion ragatouille/data/training_data_processor.py
@@ -125,7 +125,9 @@ def _get_new_negatives(self, query, passages, mine_hard_negatives, n_new_negativ
else:
new_negatives = [
x
for x in random.sample(self.collection, min(n_new_negatives, len(self.collection)))
for x in random.sample(
self.collection, min(n_new_negatives, len(self.collection))
)
if x not in passages["positives"] and x not in passages["negatives"]
]

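
Note on the hunk above: the change is only a formatter re-wrap, but the `min()` guard it preserves is load-bearing, since `random.sample` raises a ValueError whenever the requested sample size exceeds the population. A minimal standalone sketch with illustrative values:

```python
import random

collection = ["passage 1", "passage 2", "passage 3"]
n_new_negatives = 10

# random.sample(collection, n_new_negatives) would raise:
#   ValueError: Sample larger than population or is negative
sampled = random.sample(collection, min(n_new_negatives, len(collection)))
assert len(sampled) == 3  # capped at the collection size
```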
173 changes: 69 additions & 104 deletions ragatouille/models/colbert.py
@@ -32,6 +32,7 @@ def __init__(
self.verbose = verbose
self.collection = None
self.pid_docid_map = None
self.docid_pid_map = None
self.docid_metadata_map = None
self.in_memory_docs = []
self.base_model_max_tokens = 512
@@ -51,35 +52,10 @@ def __init__(
)
split_root = str(pretrained_model_name_or_path).split("/")[:-1]
self.config.root = "/".join(split_root)
self.index_root = self.config.root
self.checkpoint = self.config.checkpoint
self.index_name = self.config.index_name
self.collection = self._get_collection_from_file(
str(pretrained_model_name_or_path / "collection.json")
)
try:
self.pid_docid_map = self._get_collection_from_file(
str(pretrained_model_name_or_path / "pid_docid_map.json")
)
# convert all keys to int when loading from file because saving converts to str
self.pid_docid_map = {
int(key): value for key, value in self.pid_docid_map.items()
}
self.docid_pid_map = defaultdict(list)
for pid, docid in self.pid_docid_map.items():
self.docid_pid_map[docid].append(pid)
if os.path.exists(
str(pretrained_model_name_or_path / "docid_metadata_map.json")
):
self.docid_metadata_map = self._get_collection_from_file(
str(pretrained_model_name_or_path / "docid_metadata_map.json")
)
except Exception:
print(
"WARNING: Could not load pid_docid_map or docid_metadata_map from index!",
"This is likely because you are loading an old index.",
)
self.pid_docid_map = defaultdict(lambda: None)
self.docid_metadata_map = defaultdict(lambda: None)
self._get_collection_files_from_disk(pretrained_model_name_or_path)
# TODO: Modify root assignment when loading from HF

else:
@@ -98,7 +74,7 @@ def __init__(
self.checkpoint = pretrained_model_name_or_path
self.index_name = index_name
self.config.experiment = "colbert"
self.config.root = ".ragatouille/"
self.config.root = self.index_root

if not training_mode:
self.inference_ckpt = Checkpoint(
@@ -112,11 +88,42 @@ def __init__(
self.run_context.__enter__() # Manually enter the context
self.searcher = None

def _get_collection_from_file(self, collection_path: str):
return srsly.read_json(collection_path)
def _invert_pid_docid_map(self) -> Dict[str, int]:
return {v: k for k, v in self.pid_docid_map.items()}

def _get_collection_files_from_disk(self, index_path: str):
self.collection = srsly.read_json(index_path / "collection.json")
if os.path.exists(str(index_path / "docid_metadata_map.json")):
self.docid_metadata_map = srsly.read_json(
str(index_path / "docid_metadata_map.json")
)
else:
self.docid_metadata_map = None

try:
self.pid_docid_map = srsly.read_json(str(index_path / "pid_docid_map.json"))
except FileNotFoundError as err:
raise FileNotFoundError(
"ERROR: Could not load pid_docid_map from index!",
"This is likely because you are loading an older, incompatible index.",
) from err

# convert all keys to int when loading from file because saving converts to str
self.pid_docid_map = {
int(key): value for key, value in self.pid_docid_map.items()
}
self.docid_pid_map = self._invert_pid_docid_map()

def _write_collection_files_to_disk(self):
srsly.write_json(self.index_path + "/collection.json", self.collection)
srsly.write_json(self.index_path + "/pid_docid_map.json", self.pid_docid_map)
if self.docid_metadata_map is not None:
srsly.write_json(
self.index_path + "/docid_metadata_map.json", self.docid_metadata_map
)

def _write_collection_to_file(self, collection, collection_path: str):
srsly.write_json(collection_path, collection)
# update the in-memory inverted map every time the files are saved to disk
self.docid_pid_map = self._invert_pid_docid_map()
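
Note on the new helpers: JSON object keys are always strings, so the integer pids in `pid_docid_map` round-trip as strings, which is why `_get_collection_files_from_disk` re-casts every key with `int()`. It is also worth noting that `_invert_pid_docid_map` is a one-to-one inversion, so if several pids share a `document_id`, only the last pid is kept. A small sketch of both behaviors (illustrative values, using srsly as the code does):

```python
import srsly

# pids are ints in memory
pid_docid_map = {0: "doc-a", 1: "doc-a", 2: "doc-b"}
srsly.write_json("pid_docid_map.json", pid_docid_map)

# JSON keys always come back as strings
loaded = srsly.read_json("pid_docid_map.json")
assert list(loaded.keys()) == ["0", "1", "2"]

# hence the int() re-cast applied on load
pid_docid_map = {int(k): v for k, v in loaded.items()}

# the one-to-one inversion keeps only the last pid per document_id
docid_pid_map = {v: k for k, v in pid_docid_map.items()}
assert docid_pid_map == {"doc-a": 1, "doc-b": 2}
```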

def add_to_index(
self,
@@ -142,19 +149,18 @@ def add_to_index(
if self.loaded_from_index:
index_root = self.config.root
else:
index_root = str(
Path(self.config.root) / Path(self.config.experiment) / "indexes"
)
expected_path_segment = Path(self.config.experiment) / "indexes"
if str(expected_path_segment) in self.config.root:
index_root = self.config.root
else:
index_root = str(Path(self.config.root) / expected_path_segment)

if not self.collection:
self.collection = self._get_collection_from_file(
str(
Path(self.config.root)
/ Path(self.config.experiment)
/ "indexes"
/ self.index_name
/ "collection.json"
collection_path = Path(index_root) / self.index_name / "collection.json"
if collection_path.exists():
self._get_collection_files_from_disk(
str(Path(index_root) / self.index_name)
)
)
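
Note on the `index_root` logic above: this is the duplicate-path fix named in the PR title. When `config.root` already ends in the `<experiment>/indexes` segment (as it does after loading a model from an existing index), the old unconditional join produced doubled paths like `.ragatouille/colbert/indexes/colbert/indexes`. A sketch of the failure mode and the guard, with hypothetical POSIX-style values:

```python
from pathlib import Path

root = ".ragatouille/colbert/indexes"  # hypothetical: root already points at the indexes dir
experiment = "colbert"
expected_path_segment = Path(experiment) / "indexes"

# the old unconditional join duplicated the segment
naive = str(Path(root) / expected_path_segment)
assert naive == ".ragatouille/colbert/indexes/colbert/indexes"

# the guard keeps root as-is when the segment is already present
index_root = root if str(expected_path_segment) in root else naive
assert index_root == ".ragatouille/colbert/indexes"
```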

searcher = Searcher(
checkpoint=self.checkpoint,
@@ -170,20 +176,27 @@ def add_to_index(
new_documents_with_ids = [
{"content": doc, "document_id": new_pid_docid_map[pid]}
for pid, doc in enumerate(new_documents)
if new_pid_docid_map[pid] not in self.pid_docid_map
if new_pid_docid_map[pid] not in self.pid_docid_map.values()
]
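
Note on the comprehension above: the deduplication half of the fix. `pid_docid_map` maps integer pids to document_ids, so the old test `not in self.pid_docid_map` checked a document_id against the integer keys and never matched, meaning existing documents were always re-added. Checking `.values()` compares against the stored document_ids:

```python
pid_docid_map = {0: "doc-a", 1: "doc-b"}  # pid (int) -> document_id (str)

assert "doc-b" not in pid_docid_map         # old check: membership in the int keys, never true
assert "doc-b" in pid_docid_map.values()    # fixed check: membership in the document_ids
```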

if new_docid_metadata_map is not None:
self.docid_metadata_map = self.docid_metadata_map or {}
self.docid_metadata_map = self.docid_metadata_map or defaultdict(
lambda: None
)
self.docid_metadata_map.update(new_docid_metadata_map)

max_existing_pid = max(self.pid_docid_map.keys(), default=-1)
for idx, doc in enumerate(new_documents_with_ids, start=max_existing_pid + 1):
self.pid_docid_map[idx] = doc["document_id"]

combined_documents = self.collection + [
doc["content"] for doc in new_documents_with_ids
]
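
Note on `combined_documents`: the overwrite half of the fix. When the branch below opts for a full rebuild, `index()` must be fed the existing collection plus the new passages; previously only the new documents were passed, so a rebuild replaced the entire index with just the additions. A before/after sketch with illustrative values:

```python
collection = ["old passage 1", "old passage 2"]  # already indexed
new_contents = ["new passage 1"]                 # being added

# before the fix, a rebuild indexed only new_contents and the old passages vanished
combined_documents = collection + new_contents
assert combined_documents == ["old passage 1", "old passage 2", "new passage 1"]
```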

if current_len + new_doc_len < 5000 or new_doc_len > current_len * 0.05:
self.index(
[doc["content"] for doc in new_documents_with_ids],
{
pid: doc["document_id"]
for pid, doc in enumerate(new_documents_with_ids)
},
combined_documents,
self.pid_docid_map,
docid_metadata_map=self.docid_metadata_map,
index_name=self.index_name,
max_document_length=self.config.doc_maxlen,
@@ -197,38 +210,17 @@ def add_to_index(
updater = IndexUpdater(
config=self.config, searcher=searcher, checkpoint=self.checkpoint
)

updater.add([doc["content"] for doc in new_documents_with_ids])
updater.persist_to_disk()
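
Note: judging by commit 2fd6795 ("fixed index not being written to disk if IndexUpdater is i…"), the explicit `updater.persist_to_disk()` call here appears to address a bug where the IndexUpdater path modified the index only in memory.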

self.pid_docid_map.update(
{pid: doc["document_id"] for pid, doc in enumerate(new_documents_with_ids)}
)
self.docid_pid_map = defaultdict(list)
for pid, docid in self.pid_docid_map.items():
self.docid_pid_map[docid].append(pid)

self._write_collection_to_file(
self.pid_docid_map, self.index_path + "/pid_docid_map.json"
)
if self.docid_metadata_map is not None:
self._write_collection_to_file(
self.docid_metadata_map, self.index_path + "/docid_metadata_map.json"
)
self._write_collection_files_to_disk()

print(
f"Successfully updated index with {len(new_documents_with_ids)} new documents!\n",
f"New index size: {current_len + len(new_documents_with_ids)}",
)

self.index_path = str(
Path(self.run_config.root)
/ Path(self.run_config.experiment)
/ "indexes"
/ self.index_name
)

return self.index_path
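
For reference, the rebuild-versus-update branch above is a size heuristic: a full re-index runs when the post-update index would stay small (under 5,000 passages) or when the addition is large relative to the current index (over 5% growth); otherwise the cheaper incremental path through ColBERT's IndexUpdater is taken. A sketch of the condition (the function name is illustrative):

```python
def should_rebuild(current_len: int, new_doc_len: int) -> bool:
    # mirrors: current_len + new_doc_len < 5000 or new_doc_len > current_len * 0.05
    return current_len + new_doc_len < 5000 or new_doc_len > current_len * 0.05

assert should_rebuild(1_000, 10)         # small index: rebuild everything
assert should_rebuild(100_000, 6_000)    # >5% growth: rebuild everything
assert not should_rebuild(100_000, 100)  # large index, small delta: IndexUpdater
```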

def delete_from_index(
self,
document_ids: Union[TypeVar("T"), List[TypeVar("T")]],
@@ -275,9 +267,6 @@ def delete_from_index(
for pid, docid in self.pid_docid_map.items()
if pid not in pids_to_remove
}
self.docid_pid_map = defaultdict(list)
for pid, docid in self.pid_docid_map.items():
self.docid_pid_map[docid].append(pid)

if self.docid_metadata_map is not None:
self.docid_metadata_map = {
@@ -286,16 +275,7 @@ def delete_from_index(
if docid not in document_ids
}

self._write_collection_to_file(
self.collection, self.index_path + "/collection.json"
)
self._write_collection_to_file(
self.pid_docid_map, self.index_path + "/pid_docid_map.json"
)
if self.docid_metadata_map is not None:
self._write_collection_to_file(
self.docid_metadata_map, self.index_path + "/docid_metadata_map.json"
)
self._write_collection_files_to_disk()

print(f"Successfully deleted documents with these IDs: {document_ids}")

@@ -340,6 +320,8 @@ def index(
self.index_name = self.checkpoint + "new_index"

self.collection = collection
self.pid_docid_map = pid_docid_map
self.docid_metadata_map = docid_metadata_map

nbits = 2
if len(self.collection) < 5000:
@@ -378,25 +360,8 @@
self.config.root = str(
Path(self.run_config.root) / Path(self.run_config.experiment) / "indexes"
)
self._write_collection_to_file(
self.collection, self.index_path + "/collection.json"
)

self.pid_docid_map = pid_docid_map
self._write_collection_to_file(
self.pid_docid_map, self.index_path + "/pid_docid_map.json"
)

# inverted mapping for returning full docs
self.docid_pid_map = defaultdict(list)
for pid, docid in self.pid_docid_map.items():
self.docid_pid_map[docid].append(pid)

if docid_metadata_map is not None:
self._write_collection_to_file(
docid_metadata_map, self.index_path + "/docid_metadata_map.json"
)
self.docid_metadata_map = docid_metadata_map
self._write_collection_files_to_disk()

print("Done indexing!")
