Merge pull request #1366 from hanhainebula/master
optimize evaluation process in evaluator.py
hanhainebula authored Feb 10, 2025
2 parents fcdf889 + 3b09d80 commit b2871ac
Showing 2 changed files with 24 additions and 12 deletions.
12 changes: 8 additions & 4 deletions FlagEmbedding/abc/evaluation/evaluator.py
@@ -212,8 +212,9 @@ def __call__(
     no_reranker_search_results_dict[split] = search_results
 retriever.stop_multi_process_pool()
 eval_results_save_path = os.path.join(no_reranker_search_results_save_dir, 'EVAL', 'eval_results.json')
-retriever_eval_results = self.evaluate_results(no_reranker_search_results_save_dir, k_values=k_values)
-self.output_eval_results_to_json(retriever_eval_results, eval_results_save_path)
+if not os.path.exists(eval_results_save_path) or self.overwrite or flag:
+    retriever_eval_results = self.evaluate_results(no_reranker_search_results_save_dir, k_values=k_values)
+    self.output_eval_results_to_json(retriever_eval_results, eval_results_save_path)

 # Reranking Stage
 if reranker is not None:
@@ -229,6 +230,7 @@ def __call__(
     for split in splits
 }

+flag = False
 for split in splits:
     rerank_search_results_save_path = os.path.join(
         reranker_search_results_save_dir, save_name.format(split=split)
@@ -237,6 +239,7 @@ def __call__(
 if os.path.exists(rerank_search_results_save_path) and not self.overwrite:
     continue

+flag = True
 rerank_search_results = reranker(
     corpus=corpus,
     queries=queries_dict[split],
@@ -256,8 +259,9 @@ def __call__(
     )
 reranker.stop_multi_process_pool()
 eval_results_save_path = os.path.join(reranker_search_results_save_dir, 'EVAL', 'eval_results.json')
-reranker_eval_results = self.evaluate_results(reranker_search_results_save_dir, k_values=k_values)
-self.output_eval_results_to_json(reranker_eval_results, eval_results_save_path)
+if not os.path.exists(eval_results_save_path) or self.overwrite or flag:
+    reranker_eval_results = self.evaluate_results(reranker_search_results_save_dir, k_values=k_values)
+    self.output_eval_results_to_json(reranker_eval_results, eval_results_save_path)

 @staticmethod
 def save_search_results(
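
Taken together, the change caches evaluation work: per-split search results are only recomputed when their file is missing or overwrite is set, a flag records whether anything was actually recomputed, and the final metric computation plus JSON dump is skipped when a cached eval_results.json already exists and nothing changed. Below is a minimal standalone sketch of that pattern; run_search, compute_metrics, and the demo paths are illustrative placeholders, not FlagEmbedding's actual API.

import json
import os

def run_search(split):
    # Placeholder for the expensive retrieval / reranking step.
    return {"query_1": {"doc_1": 1.0}}

def compute_metrics(results_dir):
    # Placeholder for evaluate_results(): read saved search results, score them.
    return {"ndcg_at_10": 0.0}

def evaluate_with_cache(results_dir, splits, overwrite=False):
    os.makedirs(results_dir, exist_ok=True)
    flag = False  # True only if at least one split's search results were (re)computed
    for split in splits:
        split_path = os.path.join(results_dir, f"{split}.json")
        if os.path.exists(split_path) and not overwrite:
            continue  # cached search results exist, skip this split
        flag = True
        with open(split_path, "w") as f:
            json.dump(run_search(split), f)

    eval_path = os.path.join(results_dir, "EVAL", "eval_results.json")
    # The optimization from this commit: only re-evaluate when cached metrics
    # are missing, overwrite was requested, or new search results were produced.
    if not os.path.exists(eval_path) or overwrite or flag:
        os.makedirs(os.path.dirname(eval_path), exist_ok=True)
        with open(eval_path, "w") as f:
            json.dump(compute_metrics(results_dir), f)

if __name__ == "__main__":
    evaluate_with_cache("demo_results", ["dev", "test"])
    evaluate_with_cache("demo_results", ["dev", "test"])  # second call touches nothing

On a repeated run with the same arguments, neither the per-split search nor the metric computation is redone, which is the behaviour the commit message describes as optimizing the evaluation process.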
24 changes: 16 additions & 8 deletions FlagEmbedding/evaluation/beir/evaluator.py
@@ -164,8 +164,9 @@ def __call__(
     no_reranker_search_results_dict[split] = search_results
 retriever.stop_multi_process_pool()
 eval_results_save_path = os.path.join(no_reranker_search_results_save_dir, 'EVAL', 'eval_results.json')
-retriever_eval_results = self.evaluate_results(no_reranker_search_results_save_dir, k_values=k_values)
-self.output_eval_results_to_json(retriever_eval_results, eval_results_save_path)
+if not os.path.exists(eval_results_save_path) or self.overwrite or flag:
+    retriever_eval_results = self.evaluate_results(no_reranker_search_results_save_dir, k_values=k_values)
+    self.output_eval_results_to_json(retriever_eval_results, eval_results_save_path)

 # Reranking Stage
 if reranker is not None:
@@ -181,6 +182,7 @@ def __call__(
     for split in splits
 }

+flag = False
 for split in splits:
     rerank_search_results_save_path = os.path.join(
         reranker_search_results_save_dir, save_name.format(split=split)
@@ -189,6 +191,7 @@ def __call__(
 if os.path.exists(rerank_search_results_save_path) and not self.overwrite:
     continue

+flag = True
 rerank_search_results = reranker(
     corpus=corpus,
     queries=queries_dict[split],
@@ -208,8 +211,9 @@ def __call__(
             sub_dataset_name=sub_dataset_name,
         )
     eval_results_save_path = os.path.join(reranker_search_results_save_dir, 'EVAL', 'eval_results.json')
-    reranker_eval_results = self.evaluate_results(reranker_search_results_save_dir, k_values=k_values)
-    self.output_eval_results_to_json(reranker_eval_results, eval_results_save_path)
+    if not os.path.exists(eval_results_save_path) or self.overwrite or flag:
+        reranker_eval_results = self.evaluate_results(reranker_search_results_save_dir, k_values=k_values)
+        self.output_eval_results_to_json(reranker_eval_results, eval_results_save_path)
 else:
     for sub_dataset_name in sub_dataset_names:
         if dataset_name is not None:
@@ -291,8 +295,9 @@ def __call__(
     )
     no_reranker_search_results_dict[split] = search_results
 eval_results_save_path = os.path.join(no_reranker_search_results_save_dir, 'EVAL', 'eval_results.json')
-retriever_eval_results = self.evaluate_results(no_reranker_search_results_save_dir, k_values=k_values)
-self.output_eval_results_to_json(retriever_eval_results, eval_results_save_path)
+if not os.path.exists(eval_results_save_path) or self.overwrite or flag:
+    retriever_eval_results = self.evaluate_results(no_reranker_search_results_save_dir, k_values=k_values)
+    self.output_eval_results_to_json(retriever_eval_results, eval_results_save_path)

 # Reranking Stage
 if reranker is not None:
@@ -308,6 +313,7 @@ def __call__(
     for split in splits
 }

+flag = False
 for split in splits:
     rerank_search_results_save_path = os.path.join(
         reranker_search_results_save_dir, save_name.format(split=split)
@@ -316,6 +322,7 @@ def __call__(
 if os.path.exists(rerank_search_results_save_path) and not self.overwrite:
     continue

+flag = True
 rerank_search_results = reranker(
     corpus=corpus,
     queries=queries_dict[split],
@@ -335,8 +342,9 @@ def __call__(
             sub_dataset_name=sub_dataset_name,
         )
     eval_results_save_path = os.path.join(reranker_search_results_save_dir, 'EVAL', 'eval_results.json')
-    reranker_eval_results = self.evaluate_results(reranker_search_results_save_dir, k_values=k_values)
-    self.output_eval_results_to_json(reranker_eval_results, eval_results_save_path)
+    if not os.path.exists(eval_results_save_path) or self.overwrite or flag:
+        reranker_eval_results = self.evaluate_results(reranker_search_results_save_dir, k_values=k_values)
+        self.output_eval_results_to_json(reranker_eval_results, eval_results_save_path)
 if reranker is not None:
     reranker.stop_multi_process_pool()

