Fix add_to_index overwrites and duplicate paths #157

Status: Merged (23 commits, Mar 5, 2024)

Commits
f444700
fix: avoid overwriting the old index with just the new docs
adharm Feb 11, 2024
4aa0952
chore: ruff linting
adharm Feb 11, 2024
f0b5a7f
Merge branch 'main' into fix_add_to_index
adharm Feb 11, 2024
3d0e675
test: added assertion for checking old docs when adding to index
adharm Feb 11, 2024
e9a4385
chore: added tsv and jsonl to gitignore
adharm Feb 11, 2024
c06b14b
test: use correct var in add_to_index assertions
adharm Feb 11, 2024
d3f5d32
fix: correctly update pimd_docid_map
adharm Feb 12, 2024
3ba34e6
chore: add .devcontainers dir to gitignore
adharm Feb 12, 2024
8238f83
chore: removed commented out test for returning full documents
adharm Feb 12, 2024
017e17e
Merge branch 'main' into fix_add_to_index
adharm Feb 23, 2024
6c53400
fix: corrected index_bsize to bsize when being passed to ColBERTConfig
adharm Feb 23, 2024
f68cd8d
Revert "fix: corrected index_bsize to bsize when being passed to ColB…
adharm Feb 23, 2024
5944cc5
fix: duplicate path issue in add_to_index
adharm Feb 24, 2024
cceb0ac
test: cleaned up CRUD tests to be more verbose on failure + added log…
adharm Feb 24, 2024
514966e
style: ruff formatting
adharm Feb 24, 2024
0e1250d
test: updated add_to_index tests
adharm Feb 27, 2024
e09d725
style: ruff formatting
adharm Feb 27, 2024
2fd6795
fix: :bug: fixed index not being written to disk if IndexUpdater is i…
adharm Feb 27, 2024
faa9eb0
Merge branch 'main' into fix_add_to_index
adharm Feb 27, 2024
7d3ff09
style: ruff format
adharm Feb 28, 2024
348ea6c
fix: fixed path issues by reverting previous deletion
adharm Feb 28, 2024
00ebc51
test: removed add to empty index test
adharm Feb 28, 2024
0b15265
Merge branch 'main' into fix_add_to_index
adharm Feb 28, 2024
7 changes: 6 additions & 1 deletion .gitignore
@@ -115,6 +115,9 @@ dmypy.json
# Cython debug symbols
cython_debug/

# data files
*.tsv
*.jsonl

.mypy.ipynb_checkpoints
.mkdocs.yml
@@ -126,4 +129,6 @@ archive/

local/

.vscode/

.devcontainer/
4 changes: 3 additions & 1 deletion ragatouille/RAGTrainer.py
@@ -157,7 +157,9 @@ def prepare_training_data(
)
if len(self.data_processor.training_triplets) == 0:
if mine_hard_negatives:
print("Warning: No training triplets were generated with setting mine_hard_negatives=='True'. This may be due to the data being too small or the hard negative miner not being able to find enough hard negatives.")
print(
"Warning: No training triplets were generated with setting mine_hard_negatives=='True'. This may be due to the data being too small or the hard negative miner not being able to find enough hard negatives."
)
self.data_processor.process_raw_data(
data_type=data_type,
raw_data=raw_data,
4 changes: 3 additions & 1 deletion ragatouille/data/training_data_processor.py
@@ -125,7 +125,9 @@ def _get_new_negatives(self, query, passages, mine_hard_negatives, n_new_negativ
else:
new_negatives = [
x
for x in random.sample(self.collection, min(n_new_negatives, len(self.collection)))
for x in random.sample(
self.collection, min(n_new_negatives, len(self.collection))
)
if x not in passages["positives"] and x not in passages["negatives"]
]
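
As a side note on the sampling call kept in this hunk (the hunk only reflows it for ruff), the min() clamp is what prevents random.sample from raising ValueError when more negatives are requested than the collection holds. A minimal standalone sketch:

```python
import random

collection = ["doc_a", "doc_b", "doc_c"]
n_new_negatives = 10  # more negatives requested than documents available

# random.sample(population, k) raises ValueError if k > len(population),
# so the requested count is clamped to the collection size first.
candidates = random.sample(collection, min(n_new_negatives, len(collection)))
print(candidates)  # at most len(collection) unique items, in random order
```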

173 changes: 69 additions & 104 deletions ragatouille/models/colbert.py
@@ -32,6 +32,7 @@ def __init__(
self.verbose = verbose
self.collection = None
self.pid_docid_map = None
self.docid_pid_map = None
self.docid_metadata_map = None
self.in_memory_docs = []
self.base_model_max_tokens = 512
@@ -51,35 +52,10 @@ def __init__(
)
split_root = str(pretrained_model_name_or_path).split("/")[:-1]
self.config.root = "/".join(split_root)
self.index_root = self.config.root
self.checkpoint = self.config.checkpoint
self.index_name = self.config.index_name
self.collection = self._get_collection_from_file(
str(pretrained_model_name_or_path / "collection.json")
)
try:
self.pid_docid_map = self._get_collection_from_file(
str(pretrained_model_name_or_path / "pid_docid_map.json")
)
# convert all keys to int when loading from file because saving converts to str
self.pid_docid_map = {
int(key): value for key, value in self.pid_docid_map.items()
}
self.docid_pid_map = defaultdict(list)
for pid, docid in self.pid_docid_map.items():
self.docid_pid_map[docid].append(pid)
if os.path.exists(
str(pretrained_model_name_or_path / "docid_metadata_map.json")
):
self.docid_metadata_map = self._get_collection_from_file(
str(pretrained_model_name_or_path / "docid_metadata_map.json")
)
except Exception:
print(
"WARNING: Could not load pid_docid_map or docid_metadata_map from index!",
"This is likely because you are loading an old index.",
)
self.pid_docid_map = defaultdict(lambda: None)
self.docid_metadata_map = defaultdict(lambda: None)
self._get_collection_files_from_disk(pretrained_model_name_or_path)
# TODO: Modify root assignment when loading from HF

else:
@@ -98,7 +74,7 @@ def __init__(
self.checkpoint = pretrained_model_name_or_path
self.index_name = index_name
self.config.experiment = "colbert"
self.config.root = ".ragatouille/"
self.config.root = self.index_root

if not training_mode:
self.inference_ckpt = Checkpoint(
@@ -112,11 +88,42 @@ def __init__(
self.run_context.__enter__() # Manually enter the context
self.searcher = None

def _get_collection_from_file(self, collection_path: str):
return srsly.read_json(collection_path)
def _invert_pid_docid_map(self) -> Dict[str, int]:
return {v: k for k, v in self.pid_docid_map.items()}

def _get_collection_files_from_disk(self, index_path: str):
self.collection = srsly.read_json(index_path / "collection.json")
if os.path.exists(str(index_path / "docid_metadata_map.json")):
self.docid_metadata_map = srsly.read_json(
str(index_path / "docid_metadata_map.json")
)
else:
self.docid_metadata_map = None

try:
self.pid_docid_map = srsly.read_json(str(index_path / "pid_docid_map.json"))
except FileNotFoundError as err:
raise FileNotFoundError(
"ERROR: Could not load pid_docid_map from index!",
"This is likely because you are loading an older, incompatible index.",
) from err

# convert all keys to int when loading from file because saving converts to str
self.pid_docid_map = {
int(key): value for key, value in self.pid_docid_map.items()
}
self.docid_pid_map = self._invert_pid_docid_map()

def _write_collection_files_to_disk(self):
srsly.write_json(self.index_path + "/collection.json", self.collection)
srsly.write_json(self.index_path + "/pid_docid_map.json", self.pid_docid_map)
if self.docid_metadata_map is not None:
srsly.write_json(
self.index_path + "/docid_metadata_map.json", self.docid_metadata_map
)

def _write_collection_to_file(self, collection, collection_path: str):
srsly.write_json(collection_path, collection)
# update the in-memory inverted map every time the files are saved to disk
self.docid_pid_map = self._invert_pid_docid_map()
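
A brief aside on the helpers added above, as a hedged sketch rather than a statement about the library's exact behavior: JSON object keys are always strings, so the pid keys written to pid_docid_map.json come back as strings and are cast back to int on load, and the dict-comprehension inversion keeps a single pid per document id (the last one seen), unlike the old defaultdict(list) mapping of docid to all of its pids.

```python
import json

# pid_docid_map in memory: integer passage ids -> document ids
pid_docid_map = {0: "miyazaki", 1: "miyazaki", 2: "mononoke"}

# JSON object keys are always strings, so a write/read round trip loses the int keys...
serialized = json.loads(json.dumps(pid_docid_map))
assert list(serialized.keys()) == ["0", "1", "2"]

# ...which is why the loader casts them back to int before use.
restored = {int(k): v for k, v in serialized.items()}

# Inversion in the style of _invert_pid_docid_map: document id -> passage id.
# With several passages per document, later pids overwrite earlier ones.
docid_pid_map = {v: k for k, v in restored.items()}
assert docid_pid_map == {"miyazaki": 1, "mononoke": 2}
```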

def add_to_index(
self,
@@ -142,19 +149,18 @@ def add_to_index(
if self.loaded_from_index:
index_root = self.config.root
else:
index_root = str(
Path(self.config.root) / Path(self.config.experiment) / "indexes"
)
expected_path_segment = Path(self.config.experiment) / "indexes"
if str(expected_path_segment) in self.config.root:
index_root = self.config.root
else:
index_root = str(Path(self.config.root) / expected_path_segment)

if not self.collection:
self.collection = self._get_collection_from_file(
str(
Path(self.config.root)
/ Path(self.config.experiment)
/ "indexes"
/ self.index_name
/ "collection.json"
collection_path = Path(index_root) / self.index_name / "collection.json"
if collection_path.exists():
self._get_collection_files_from_disk(
str(Path(index_root) / self.index_name)
)
)

searcher = Searcher(
checkpoint=self.checkpoint,
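
The index_root change above is the duplicate-path half of this PR: previously the "<experiment>/indexes" segment was appended unconditionally, so a config root that already ended in it got the segment twice. A minimal sketch of the same guard on toy inputs (the function wrapper and defaults are illustrative, not the library's API):

```python
from pathlib import Path

def resolve_index_root(config_root: str, experiment: str = "colbert") -> str:
    # Only append "<experiment>/indexes" when config_root does not already contain it.
    expected_segment = Path(experiment) / "indexes"
    if str(expected_segment) in config_root:
        return config_root
    return str(Path(config_root) / expected_segment)

# Without the guard, an already-qualified root would double up, e.g.
# ".ragatouille/colbert/indexes/colbert/indexes/..." (POSIX-style paths assumed).
assert resolve_index_root(".ragatouille/") == ".ragatouille/colbert/indexes"
assert resolve_index_root(".ragatouille/colbert/indexes") == ".ragatouille/colbert/indexes"
```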
@@ -170,20 +176,27 @@ def add_to_index(
new_documents_with_ids = [
{"content": doc, "document_id": new_pid_docid_map[pid]}
for pid, doc in enumerate(new_documents)
if new_pid_docid_map[pid] not in self.pid_docid_map
if new_pid_docid_map[pid] not in self.pid_docid_map.values()
]

if new_docid_metadata_map is not None:
self.docid_metadata_map = self.docid_metadata_map or {}
self.docid_metadata_map = self.docid_metadata_map or defaultdict(
lambda: None
)
self.docid_metadata_map.update(new_docid_metadata_map)

max_existing_pid = max(self.pid_docid_map.keys(), default=-1)
for idx, doc in enumerate(new_documents_with_ids, start=max_existing_pid + 1):
self.pid_docid_map[idx] = doc["document_id"]

combined_documents = self.collection + [
doc["content"] for doc in new_documents_with_ids
]

if current_len + new_doc_len < 5000 or new_doc_len > current_len * 0.05:
self.index(
[doc["content"] for doc in new_documents_with_ids],
{
pid: doc["document_id"]
for pid, doc in enumerate(new_documents_with_ids)
},
combined_documents,
self.pid_docid_map,
docid_metadata_map=self.docid_metadata_map,
index_name=self.index_name,
max_document_length=self.config.doc_maxlen,
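
This hunk is the core of the overwrite fix: incoming passages are deduplicated against the existing document ids (.values(), not the pid keys), new pids continue after the highest existing pid instead of restarting at 0, and the old collection is re-indexed together with the genuinely new passages. A hedged, self-contained sketch of those three steps on toy data (the variable names mirror the diff, but nothing here is the real class):

```python
existing_collection = ["passage about Miyazaki", "another Miyazaki passage"]
pid_docid_map = {0: "miyazaki", 1: "miyazaki"}

new_documents = ["passage about Mononoke", "duplicate Miyazaki passage"]
new_pid_docid_map = {0: "mononoke", 1: "miyazaki"}

# 1. Drop incoming passages whose document_id is already indexed.
new_documents_with_ids = [
    {"content": doc, "document_id": new_pid_docid_map[pid]}
    for pid, doc in enumerate(new_documents)
    if new_pid_docid_map[pid] not in pid_docid_map.values()
]

# 2. Continue pid numbering after the highest existing pid; restarting at 0
#    was what previously overwrote the old entries.
max_existing_pid = max(pid_docid_map.keys(), default=-1)
for pid, doc in enumerate(new_documents_with_ids, start=max_existing_pid + 1):
    pid_docid_map[pid] = doc["document_id"]

# 3. Re-index the old collection plus the genuinely new passages.
combined_documents = existing_collection + [d["content"] for d in new_documents_with_ids]

assert pid_docid_map == {0: "miyazaki", 1: "miyazaki", 2: "mononoke"}
assert combined_documents[-1] == "passage about Mononoke"
```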
@@ -197,38 +210,17 @@ def add_to_index(
updater = IndexUpdater(
config=self.config, searcher=searcher, checkpoint=self.checkpoint
)

updater.add([doc["content"] for doc in new_documents_with_ids])
updater.persist_to_disk()

self.pid_docid_map.update(
{pid: doc["document_id"] for pid, doc in enumerate(new_documents_with_ids)}
)
self.docid_pid_map = defaultdict(list)
for pid, docid in self.pid_docid_map.items():
self.docid_pid_map[docid].append(pid)

self._write_collection_to_file(
self.pid_docid_map, self.index_path + "/pid_docid_map.json"
)
if self.docid_metadata_map is not None:
self._write_collection_to_file(
self.docid_metadata_map, self.index_path + "/docid_metadata_map.json"
)
self._write_collection_files_to_disk()

print(
f"Successfully updated index with {len(new_documents_with_ids)} new documents!\n",
f"New index size: {current_len + len(new_documents_with_ids)}",
)

self.index_path = str(
Path(self.run_config.root)
/ Path(self.run_config.experiment)
/ "indexes"
/ self.index_name
)

return self.index_path

def delete_from_index(
self,
document_ids: Union[TypeVar("T"), List[TypeVar("T")]],
@@ -275,9 +267,6 @@ def delete_from_index(
for pid, docid in self.pid_docid_map.items()
if pid not in pids_to_remove
}
self.docid_pid_map = defaultdict(list)
for pid, docid in self.pid_docid_map.items():
self.docid_pid_map[docid].append(pid)

if self.docid_metadata_map is not None:
self.docid_metadata_map = {
@@ -286,16 +275,7 @@ def delete_from_index(
if docid not in document_ids
}

self._write_collection_to_file(
self.collection, self.index_path + "/collection.json"
)
self._write_collection_to_file(
self.pid_docid_map, self.index_path + "/pid_docid_map.json"
)
if self.docid_metadata_map is not None:
self._write_collection_to_file(
self.docid_metadata_map, self.index_path + "/docid_metadata_map.json"
)
self._write_collection_files_to_disk()

print(f"Successfully deleted documents with these IDs: {document_ids}")

@@ -340,6 +320,8 @@ def index(
self.index_name = self.checkpoint + "new_index"

self.collection = collection
self.pid_docid_map = pid_docid_map
self.docid_metadata_map = docid_metadata_map

nbits = 2
if len(self.collection) < 5000:
@@ -378,25 +360,8 @@ def index(
self.config.root = str(
Path(self.run_config.root) / Path(self.run_config.experiment) / "indexes"
)
self._write_collection_to_file(
self.collection, self.index_path + "/collection.json"
)

self.pid_docid_map = pid_docid_map
self._write_collection_to_file(
self.pid_docid_map, self.index_path + "/pid_docid_map.json"
)

# inverted mapping for returning full docs
self.docid_pid_map = defaultdict(list)
for pid, docid in self.pid_docid_map.items():
self.docid_pid_map[docid].append(pid)

if docid_metadata_map is not None:
self._write_collection_to_file(
docid_metadata_map, self.index_path + "/docid_metadata_map.json"
)
self.docid_metadata_map = docid_metadata_map
self._write_collection_files_to_disk()

print("Done indexing!")

112 changes: 53 additions & 59 deletions tests/test_pretrained_optional_args.py
@@ -219,69 +219,16 @@ def test_document_metadata_returned_in_search_results(
), "The metadata should not be returned in the results."


# def test_return_entire_document(index_creation_inputs, index_path_fixture):
# if index_creation_inputs["split_documents"] == True:
# RAG = RAGPretrainedModel.from_index(index_path_fixture)
# results = RAG.search(
# "when was miyazaki born",
# index_name=index_creation_inputs["index_name"],
# return_entire_document=True,
# )
# for result in results:
# assert (
# "entire_document" in result
# ), "The full document should be returned in the results."
# doc_id = result["document_id"]
# expected_document = index_creation_inputs["collection"][
# index_creation_inputs["document_ids"].index(doc_id)
# ]
# assert (
# result["entire_document"] == expected_document
# ), f"The document for document_id {doc_id} should match the provided document."
# else:
# assert True, "This test is only relevant for split documents."


# TODO: move this to a separate test file
def test_delete_from_index(
index_creation_inputs,
pid_docid_map_path_fixture,
document_metadata_path_fixture,
index_path_fixture,
):
RAG = RAGPretrainedModel.from_index(index_path_fixture)
deleted_doc_id = index_creation_inputs["document_ids"][0]
original_doc_ids = set(index_creation_inputs["document_ids"])
RAG.delete_from_index(
index_name=index_creation_inputs["index_name"],
document_ids=[deleted_doc_id],
)
pid_docid_map_data = srsly.read_json(pid_docid_map_path_fixture)
updated_document_ids = set(list(pid_docid_map_data.values()))
assert (
deleted_doc_id not in updated_document_ids
), "Deleted document ID should not be in the collection."
assert original_doc_ids - updated_document_ids == {
deleted_doc_id
}, "Only the deleted document ID should be missing from the collection."
if "document_metadatas" in index_creation_inputs:
document_metadata_dict = srsly.read_json(document_metadata_path_fixture)
assert (
deleted_doc_id not in document_metadata_dict
), "Deleted document ID should not be in the document metadata."
assert original_doc_ids - set(document_metadata_dict.keys()) == {
deleted_doc_id
}, "Only the deleted document ID should be missing from the document metadata."


# TODO: move this to a separate test file
def test_add_to_index(
# TODO: move this to a separate CRUD test file
# TODO: add checks for metadata and doc content
def test_add_to_existing_index(
index_creation_inputs,
document_metadata_path_fixture,
pid_docid_map_path_fixture,
index_path_fixture,
):
RAG = RAGPretrainedModel.from_index(index_path_fixture)
existing_doc_ids = index_creation_inputs["document_ids"]
new_doc_ids = ["mononoke", "sabaku_no_tami"]
new_docs = [
"Princess Mononoke (Japanese: もののけ姫, Hepburn: Mononoke-hime) is a 1997 Japanese animated epic historical fantasy film written and directed by Hayao Miyazaki and animated by Studio Ghibli for Tokuma Shoten, Nippon Television Network and Dentsu. The film stars the voices of Yōji Matsuda, Yuriko Ishida, Yūko Tanaka, Kaoru Kobayashi, Masahiko Nishimura, Tsunehiko Kamijo, Akihiro Miwa, Mitsuko Mori, and Hisaya Morishige.\nPrincess Mononoke is set in the late Muromachi period of Japan (approximately 1336 to 1573 AD) and includes fantasy elements. The story follows a young Emishi prince named Ashitaka, and his involvement in a struggle between the gods (kami) of a forest and the humans who consume its resources. The film deals with themes of Shinto and environmentalism.\nThe film was released in Japan on July 12, 1997, by Toho, and in the United States on October 29, 1999. This was the first Studio Ghibli film in the United States to be rated PG-13 by the MPA. It was a critical and commercial blockbuster, becoming the highest-grossing film in Japan of 1997, and also held Japan's box office record for domestic films until 2001's Spirited Away, another Miyazaki film. It was dubbed into English with a script by Neil Gaiman and initially distributed in North America by Miramax, where it sold well on home media despite not performing strongly at the box office. The film greatly increased Ghibli's popularity and influence outside Japan.",
@@ -301,10 +248,57 @@ def test_add_to_index(
document_ids = set(list(pid_docid_map_data.values()))

document_metadata_dict = srsly.read_json(document_metadata_path_fixture)
# check for new docs
for new_doc_id in new_doc_ids:
assert (
new_doc_id in document_ids
), f"New document ID {new_doc_id} should be in the pid_docid_map."
), f"New document ID '{new_doc_id}' should be in the pid_docid_map's document_ids:{document_ids}."

assert (
new_doc_id in document_metadata_dict
), f"New document ID {new_doc_id} should be in the document metadata."
), f"New document ID '{new_doc_id}' should be in the document metadata keys:{document_metadata_dict.keys}."

for existing_doc_id in existing_doc_ids:
assert (
existing_doc_id in document_ids
), f"Old document ID '{existing_doc_id}' should be in the pid_docid_map's document_ids:{document_ids}."

if "document_metadatas" in index_creation_inputs:
assert (
existing_doc_id in document_metadata_dict
), f"Old document ID '{existing_doc_id}' should be in the document metadata keys:{document_metadata_dict.keys}."


# TODO: move this to a separate CRUD test file
def test_delete_from_index(
index_creation_inputs,
pid_docid_map_path_fixture,
document_metadata_path_fixture,
index_path_fixture,
):
RAG = RAGPretrainedModel.from_index(index_path_fixture)
deleted_doc_id = index_creation_inputs["document_ids"][0]
original_doc_ids = set(index_creation_inputs["document_ids"])
RAG.delete_from_index(
index_name=index_creation_inputs["index_name"],
document_ids=[deleted_doc_id],
)
pid_docid_map_data = srsly.read_json(pid_docid_map_path_fixture)
updated_document_ids = set(list(pid_docid_map_data.values()))

assert (
deleted_doc_id not in updated_document_ids
), f"Deleted document ID '{deleted_doc_id}' should not be in the pid_docid_map's document_ids: {updated_document_ids}."

assert (
original_doc_ids - updated_document_ids == {deleted_doc_id}
), f"Only the deleted document ID '{deleted_doc_id}' should be missing from the pid_docid_map's document_ids: {updated_document_ids}."

if "document_metadatas" in index_creation_inputs:
document_metadata_dict = srsly.read_json(document_metadata_path_fixture)
assert (
deleted_doc_id not in document_metadata_dict
), f"Deleted document ID '{deleted_doc_id}' should not be in the document metadata: {document_metadata_dict.keys}."
assert (
original_doc_ids - set(document_metadata_dict.keys()) == {deleted_doc_id}
), f"Only the deleted document ID '{deleted_doc_id}' should be missing from the document metadata: {document_metadata_dict.keys}."