QoL for monot5 reranking
Thilina Rajapakse committed May 27, 2024
1 parent 3e2944f commit f89bc1b
Showing 2 changed files with 93 additions and 1 deletion.
2 changes: 1 addition & 1 deletion simpletransformers/t5/t5_model.py
@@ -1140,7 +1140,7 @@ def predict(self, to_predict, reranking_eval=False):
        else:
            return outputs

-    def rerank(self, eval_data, qrels=None):
+    def rerank(self, eval_data, qrels=None, run_dict=None, beir_format=False):
        """
        Used with monoT5 style models for reranking
        """
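A minimal usage sketch of the extended rerank() signature. The hunk does not show how the new arguments are used, so this assumes run_dict is the path to a first-stage retrieval run in JSON and beir_format indicates that eval_data is a BEIR-style dataset directory; the model name, paths, and qrels format below are hypothetical.

from simpletransformers.t5 import T5Model

# Sketch only: checkpoint, dataset directory, and run file paths are assumptions.
model = T5Model("t5", "castorini/monot5-base-msmarco")

results = model.rerank(
    "data/scifact",                       # assumed BEIR-format dataset directory
    qrels="data/scifact/qrels/test.tsv",  # hypothetical qrels path
    run_dict="runs/bm25_top100.json",     # hypothetical first-stage run (JSON)
    beir_format=True,
)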
92 changes: 92 additions & 0 deletions simpletransformers/t5/t5_utils.py
@@ -1,9 +1,11 @@
import json
import logging
import os
import pickle
from multiprocessing import Pool
from os import truncate
from typing import Tuple
import warnings

import pandas as pd
import torch
@@ -279,3 +281,93 @@ def __len__(self):

    def __getitem__(self, index):
        return self.examples[index]


def convert_beir_to_monot5_format(data, run_dict=None, top_k=None, include_title=False, save_path=None):
    """
    Utility function to convert a dataset in the BEIR format to the MonoT5 format.

    Args:
        data: A directory containing a dataset in the BEIR format.
        run_dict: Path to a run file used to build a reranking dataset. If not provided, all documents are considered.
            The run file should be a JSON file with the following format:
            {
                "query_id1": {"doc_id1": score1, "doc_id2": score2, ...},
                "query_id2": {"doc_id1": score1, "doc_id2": score2, ...},
                ...
            }
        top_k: Number of documents to consider for reranking per query. Only used if run_dict is provided.
        include_title: Whether to prepend the document title to the document text in the MonoT5 format.
        save_path: Path to save the converted dataset as a TSV file. The dataset is also returned as a DataFrame.
    """

    if run_dict:
        with open(run_dict, "r") as f:
            run_dict = json.load(f)
        if top_k:
            for query_id in run_dict:
                run_dict[query_id] = dict(
                    sorted(run_dict[query_id].items(), key=lambda x: x[1], reverse=True)[:top_k]
                )

        # Make sure both query_id and doc_id are strings
        updated_dict = {}
        for query_id in run_dict:
            updated_dict[str(query_id)] = {str(k): v for k, v in run_dict[query_id].items()}

        run_dict = updated_dict
    else:
        if top_k:
            warnings.warn(
                "top_k is only used when run_dict is provided. Ignoring top_k."
            )

    queries_df = pd.read_json(os.path.join(data, "queries.jsonl"), lines=True)
    corpus_df = pd.read_json(os.path.join(data, "corpus.jsonl"), lines=True)

    queries_df["_id"] = queries_df["_id"].astype(str)
    corpus_df["_id"] = corpus_df["_id"].astype(str)

    if include_title:
        corpus_df["text"] = corpus_df["title"] + " " + corpus_df["text"]

    queries_df = queries_df.set_index("_id")
    corpus_df = corpus_df.set_index("_id")

    # If a first-stage run is provided, rerank only its candidate passages per query;
    # otherwise, pair every query with every passage in the corpus.
    if run_dict:
        reranking_data = []
        for query_id in tqdm(run_dict, total=len(run_dict)):
            for passage_id in run_dict[query_id]:
                reranking_data.append(
                    {
                        "query_id": query_id,
                        "query": queries_df.loc[query_id]["text"],
                        "passage_id": passage_id,
                        "passage": corpus_df.loc[passage_id]["text"],
                    }
                )
    else:
        reranking_data = []
        for query_id, query in tqdm(queries_df.iterrows(), total=len(queries_df)):
            for passage_id, passage in corpus_df.iterrows():
                reranking_data.append(
                    {
                        "query_id": query_id,
                        "query": query["text"],
                        "passage_id": passage_id,
                        "passage": passage["text"],
                    }
                )

    # The MonoT5 format DataFrame should have the columns: query_id, passage_id, input_text
    # input_text should be in the format: "Query: <query> Document: <document> Relevant:"
    reranking_df = pd.DataFrame(reranking_data)
    reranking_df["input_text"] = reranking_df.apply(
        lambda x: f"Query: {x['query']} Document: {x['passage']} Relevant:", axis=1
    )

    reranking_df = reranking_df[["query_id", "passage_id", "input_text"]]

    if save_path:
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        reranking_df.to_csv(save_path, sep="\t", index=False)

    return reranking_df
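For context, a minimal sketch of how the new helper might be called on a BEIR dataset. The directory, run file, and output path are hypothetical, and top_k / include_title are shown only to illustrate the arguments documented in the docstring above.

from simpletransformers.t5.t5_utils import convert_beir_to_monot5_format

# Sketch only: all paths are assumptions, not part of this commit.
reranking_df = convert_beir_to_monot5_format(
    data="data/scifact",                       # BEIR directory with queries.jsonl / corpus.jsonl
    run_dict="runs/bm25_scifact_top100.json",  # first-stage run: {query_id: {doc_id: score, ...}, ...}
    top_k=100,
    include_title=True,
    save_path="outputs/scifact_monot5.tsv",    # optional; the DataFrame is returned either way
)

The resulting DataFrame (or saved TSV) can then presumably be passed to a monoT5-style T5Model for scoring, for example via the rerank() method extended in t5_model.py above.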
