Merge pull request #98 from beir-cellar/development

Merge all latest development changes into master. Changes added to the latest BEIR master branch:

1. Added support for the T5 reranking model: monoT5 rerankers.
2. Added a Hugging Face data loader for BEIR datasets and uploaded all datasets to the Hugging Face Hub.
3. Added multi-GPU evaluation using the SBERT parallel encoding code.
4. Added the HNSWSQ method to the faiss retrieval methods.
5. Added the datasets library as a dependency in setup.py.
6. Shortened README.md significantly and created a separate BEIR wiki.
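Item 3 lands as the new DenseRetrievalParallelExactSearch class in the diff below. As a rough, hedged sketch of how the multi-GPU evaluation might be driven, assuming BEIR's existing SentenceBERT wrapper and EvaluateRetrieval API (the checkpoint name and exact constructor arguments are assumptions):

```python
# Hedged sketch of multi-GPU dense retrieval (not verbatim from this commit).
from beir.retrieval import models
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalParallelExactSearch

# Wraps an SBERT bi-encoder; corpus encoding is sharded across all visible
# GPUs via SentenceTransformers' multi-process pool ("SBERT parallel code").
model = DenseRetrievalParallelExactSearch(
    models.SentenceBERT("msmarco-distilbert-base-v3"),  # checkpoint name assumed
    batch_size=128,
)
retriever = EvaluateRetrieval(model, score_function="cos_sim")
# corpus and queries come from a BEIR data loader (see HFDataLoader below).
results = retriever.retrieve(corpus, queries)
```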
Showing 17 changed files with 983 additions and 390 deletions.
@@ -0,0 +1,118 @@

```python
from collections import defaultdict
from typing import Dict, Tuple
import os
import logging
from datasets import load_dataset, Value, Features

logger = logging.getLogger(__name__)


class HFDataLoader:

    def __init__(self, hf_repo: str = None, hf_repo_qrels: str = None, data_folder: str = None, prefix: str = None,
                 corpus_file: str = "corpus.jsonl", query_file: str = "queries.jsonl",
                 qrels_folder: str = "qrels", qrels_file: str = "", streaming: bool = False,
                 keep_in_memory: bool = False):
        self.corpus = {}
        self.queries = {}
        self.qrels = {}
        self.hf_repo = hf_repo
        if hf_repo:
            logger.warning("A Hugging Face repository is provided. This will override the data_folder, prefix and *_file arguments.")
            # The qrels live in a companion repository, "<hf_repo>-qrels" by default.
            self.hf_repo_qrels = hf_repo_qrels if hf_repo_qrels else hf_repo + "-qrels"
        else:
            # The data folder would contain these files:
            # (1) fiqa/corpus.jsonl   (format: jsonlines)
            # (2) fiqa/queries.jsonl  (format: jsonlines)
            # (3) fiqa/qrels/test.tsv (format: tsv ("\t"))
            if prefix:
                query_file = prefix + "-" + query_file
                qrels_folder = prefix + "-" + qrels_folder

            self.corpus_file = os.path.join(data_folder, corpus_file) if data_folder else corpus_file
            self.query_file = os.path.join(data_folder, query_file) if data_folder else query_file
            self.qrels_folder = os.path.join(data_folder, qrels_folder) if data_folder else None
            self.qrels_file = qrels_file
        self.streaming = streaming
        self.keep_in_memory = keep_in_memory

    @staticmethod
    def check(fIn: str, ext: str):
        if not os.path.exists(fIn):
            raise ValueError("File {} not present! Please provide an accurate file.".format(fIn))

        if not fIn.endswith(ext):
            raise ValueError("File {} must be present with extension {}".format(fIn, ext))

    def load(self, split="test") -> Tuple[Dict[str, Dict[str, str]], Dict[str, str], Dict[str, Dict[str, int]]]:

        if not self.hf_repo:
            self.qrels_file = os.path.join(self.qrels_folder, split + ".tsv")
            self.check(fIn=self.corpus_file, ext="jsonl")
            self.check(fIn=self.query_file, ext="jsonl")
            self.check(fIn=self.qrels_file, ext="tsv")

        if not len(self.corpus):
            logger.info("Loading Corpus...")
            self._load_corpus()
            logger.info("Loaded %d %s Documents.", len(self.corpus), split.upper())
            logger.info("Doc Example: %s", self.corpus[0])

        if not len(self.queries):
            logger.info("Loading Queries...")
            self._load_queries()

        self._load_qrels(split)
        # Filter out queries that have no qrels.
        qrels_dict = defaultdict(dict)

        def qrels_dict_init(row):
            qrels_dict[row['query-id']][row['corpus-id']] = int(row['score'])

        self.qrels.map(qrels_dict_init)
        self.qrels = qrels_dict
        self.queries = self.queries.filter(lambda x: x['id'] in self.qrels)
        logger.info("Loaded %d %s Queries.", len(self.queries), split.upper())
        logger.info("Query Example: %s", self.queries[0])

        return self.corpus, self.queries, self.qrels

    def load_corpus(self) -> Dict[str, Dict[str, str]]:
        if not self.hf_repo:
            self.check(fIn=self.corpus_file, ext="jsonl")

        if not len(self.corpus):
            logger.info("Loading Corpus...")
            self._load_corpus()
            logger.info("Loaded %d Documents.", len(self.corpus))
            logger.info("Doc Example: %s", self.corpus[0])

        return self.corpus

    def _load_corpus(self):
        if self.hf_repo:
            corpus_ds = load_dataset(self.hf_repo, 'corpus', keep_in_memory=self.keep_in_memory, streaming=self.streaming)
        else:
            corpus_ds = load_dataset('json', data_files=self.corpus_file, streaming=self.streaming, keep_in_memory=self.keep_in_memory)
        corpus_ds = next(iter(corpus_ds.values()))  # get the first split
        corpus_ds = corpus_ds.cast_column('_id', Value('string'))
        corpus_ds = corpus_ds.rename_column('_id', 'id')
        corpus_ds = corpus_ds.remove_columns([col for col in corpus_ds.column_names if col not in ['id', 'text', 'title']])
        self.corpus = corpus_ds

    def _load_queries(self):
        if self.hf_repo:
            queries_ds = load_dataset(self.hf_repo, 'queries', keep_in_memory=self.keep_in_memory, streaming=self.streaming)
        else:
            queries_ds = load_dataset('json', data_files=self.query_file, streaming=self.streaming, keep_in_memory=self.keep_in_memory)
        queries_ds = next(iter(queries_ds.values()))  # get the first split
        queries_ds = queries_ds.cast_column('_id', Value('string'))
        queries_ds = queries_ds.rename_column('_id', 'id')
        queries_ds = queries_ds.remove_columns([col for col in queries_ds.column_names if col not in ['id', 'text']])
        self.queries = queries_ds

    def _load_qrels(self, split):
        if self.hf_repo:
            qrels_ds = load_dataset(self.hf_repo_qrels, keep_in_memory=self.keep_in_memory, streaming=self.streaming)[split]
        else:
            qrels_ds = load_dataset('csv', data_files=self.qrels_file, delimiter='\t', keep_in_memory=self.keep_in_memory)
        # The qrels TSV has three columns: query-id, corpus-id and score.
        features = Features({'query-id': Value('string'), 'corpus-id': Value('string'), 'score': Value('float')})
        qrels_ds = qrels_ds.cast(features)
        self.qrels = qrels_ds
```
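A brief usage sketch of the loader above, assuming the BEIR datasets uploaded to the Hugging Face Hub follow a "BeIR/&lt;dataset&gt;" plus "BeIR/&lt;dataset&gt;-qrels" naming scheme (the repository ids and the module path in the import are assumptions):

```python
# Hedged sketch: load the FiQA test split straight from the Hugging Face Hub.
from beir.datasets.data_loader_hf import HFDataLoader  # module path assumed

corpus, queries, qrels = HFDataLoader(
    hf_repo="BeIR/fiqa",              # corpus + queries repository (assumed id)
    hf_repo_qrels="BeIR/fiqa-qrels",  # qrels repository (assumed id)
    streaming=False,
    keep_in_memory=False,
).load(split="test")

# corpus and queries are datasets.Dataset objects with 'id'/'text' (and 'title')
# columns; qrels is a plain dict: {query-id: {corpus-id: score}}.
```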
@@ -1 +1,2 @@

```python
from .cross_encoder import CrossEncoder
from .mono_t5 import MonoT5
```
@@ -0,0 +1,162 @@

```python
# The majority of this code has been copied from the PyGaggle MonoT5 implementation:
# https://github.com/castorini/pygaggle/blob/master/pygaggle/rerank/transformer.py

from transformers import (AutoTokenizer,
                          AutoModelForSeq2SeqLM,
                          PreTrainedModel,
                          PreTrainedTokenizer,
                          T5ForConditionalGeneration)
from typing import List, Union, Tuple, Mapping, Optional
from dataclasses import dataclass
from tqdm.autonotebook import trange
import torch


TokenizerReturnType = Mapping[str, Union[torch.Tensor, List[int],
                                         List[List[int]],
                                         List[List[str]]]]


@dataclass
class QueryDocumentBatch:
    query: str
    documents: List[str]
    output: Optional[TokenizerReturnType] = None

    def __len__(self):
        return len(self.documents)


class QueryDocumentBatchTokenizer:
    def __init__(self,
                 tokenizer: PreTrainedTokenizer,
                 pattern: str = '{query} {document}',
                 **tokenizer_kwargs):
        self.tokenizer = tokenizer
        self.tokenizer_kwargs = tokenizer_kwargs
        self.pattern = pattern

    def encode(self, strings: List[str]):
        assert self.tokenizer and self.tokenizer_kwargs is not None, \
            'mixin used improperly'
        ret = self.tokenizer.batch_encode_plus(strings,
                                               **self.tokenizer_kwargs)
        ret['tokens'] = list(map(self.tokenizer.tokenize, strings))
        return ret

    def traverse_query_document(
            self, batch_input: Tuple[str, List[str]], batch_size: int):
        query, doc_texts = batch_input[0], batch_input[1]
        for batch_idx in range(0, len(doc_texts), batch_size):
            docs = doc_texts[batch_idx:batch_idx + batch_size]
            outputs = self.encode([self.pattern.format(
                query=query,
                document=doc) for doc in docs])
            yield QueryDocumentBatch(query, docs, outputs)


class T5BatchTokenizer(QueryDocumentBatchTokenizer):
    def __init__(self, *args, **kwargs):
        kwargs['pattern'] = 'Query: {query} Document: {document} Relevant:'
        if 'return_attention_mask' not in kwargs:
            kwargs['return_attention_mask'] = True
        if 'padding' not in kwargs:
            kwargs['padding'] = 'longest'
        if 'truncation' not in kwargs:
            kwargs['truncation'] = True
        if 'return_tensors' not in kwargs:
            kwargs['return_tensors'] = 'pt'
        if 'max_length' not in kwargs:
            kwargs['max_length'] = 512
        super().__init__(*args, **kwargs)


@torch.no_grad()
def greedy_decode(model: PreTrainedModel,
                  input_ids: torch.Tensor,
                  length: int,
                  attention_mask: torch.Tensor = None,
                  return_last_logits: bool = True) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    # Start every sequence in the batch from the decoder start token.
    decode_ids = torch.full((input_ids.size(0), 1),
                            model.config.decoder_start_token_id,
                            dtype=torch.long).to(input_ids.device)
    encoder_outputs = model.get_encoder()(input_ids, attention_mask=attention_mask)
    next_token_logits = None
    for _ in range(length):
        model_inputs = model.prepare_inputs_for_generation(
            decode_ids,
            encoder_outputs=encoder_outputs,
            past=None,
            attention_mask=attention_mask,
            use_cache=True)
        outputs = model(**model_inputs)  # (batch_size, cur_len, vocab_size)
        next_token_logits = outputs[0][:, -1, :]  # (batch_size, vocab_size)
        decode_ids = torch.cat([decode_ids,
                                next_token_logits.max(1)[1].unsqueeze(-1)],
                               dim=-1)
    if return_last_logits:
        return decode_ids, next_token_logits
    return decode_ids


class MonoT5:
    def __init__(self,
                 model_path: str,
                 tokenizer: QueryDocumentBatchTokenizer = None,
                 use_amp=True,
                 token_false=None,
                 token_true=None):
        self.model = self.get_model(model_path)
        self.tokenizer = tokenizer or self.get_tokenizer(model_path)
        self.token_false_id, self.token_true_id = self.get_prediction_tokens(
            model_path, self.tokenizer, token_false, token_true)
        self.model_path = model_path
        self.device = next(self.model.parameters(), None).device
        self.use_amp = use_amp

    @staticmethod
    def get_model(model_path: str, *args, device: str = None, **kwargs) -> T5ForConditionalGeneration:
        device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        device = torch.device(device)
        return AutoModelForSeq2SeqLM.from_pretrained(model_path, *args, **kwargs).to(device).eval()

    @staticmethod
    def get_tokenizer(model_path: str, *args, **kwargs) -> T5BatchTokenizer:
        return T5BatchTokenizer(
            AutoTokenizer.from_pretrained(model_path, use_fast=False, *args, **kwargs)
        )

    @staticmethod
    def get_prediction_tokens(model_path: str, tokenizer, token_false, token_true):
        if not (token_false and token_true):
            # Fall back to the SentencePiece tokens used by the standard monoT5
            # checkpoints ("false"/"true"); pass token_false/token_true explicitly
            # for models trained with different target words.
            token_false, token_true = '▁false', '▁true'
        token_false_id = tokenizer.tokenizer.get_vocab()[token_false]
        token_true_id = tokenizer.tokenizer.get_vocab()[token_true]
        return token_false_id, token_true_id

    def predict(self, sentences: List[Tuple[str, str]], batch_size: int = 32, **kwargs) -> List[float]:

        sentence_dict, queries, scores = {}, [], []

        # The T5 model requires a batch of a single query with its top-k documents.
        for (query, doc_text) in sentences:
            if query not in sentence_dict:
                sentence_dict[query] = []
                queries.append(query)  # Preserves the order of the queries
            sentence_dict[query].append(doc_text)

        for start_idx in trange(0, len(queries), 1):  # Take one query at a time
            batch_input = (queries[start_idx], sentence_dict[queries[start_idx]])  # (single query, top-k docs)
            for batch in self.tokenizer.traverse_query_document(batch_input, batch_size):
                with torch.cuda.amp.autocast(enabled=self.use_amp):
                    input_ids = batch.output['input_ids'].to(self.device)
                    attn_mask = batch.output['attention_mask'].to(self.device)
                    _, batch_scores = greedy_decode(self.model,
                                                    input_ids,
                                                    length=1,
                                                    attention_mask=attn_mask,
                                                    return_last_logits=True)

                    # Keep only the logits of the "false"/"true" tokens and use
                    # log P("true") as the relevance score for each document.
                    batch_scores = batch_scores[:, [self.token_false_id, self.token_true_id]]
                    batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
                    batch_log_probs = batch_scores[:, 1].tolist()
                    scores.extend(batch_log_probs)

        assert len(scores) == len(sentences)  # Sanity check, should be equal
        return scores
```
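Plugged into BEIR's reranking flow, the class above could be used roughly as follows; the Rerank wrapper and the castorini/monot5-base-msmarco checkpoint are assumptions based on the surrounding codebase and the standard monoT5 releases:

```python
# Hedged sketch: rerank first-stage retrieval results with monoT5.
from beir.reranking import Rerank  # wrapper class assumed to live in beir.reranking
from beir.reranking.models import MonoT5

reranker = Rerank(MonoT5("castorini/monot5-base-msmarco"), batch_size=32)

# Tiny toy inputs; in practice these come from a data loader and a retriever.
corpus = {"d1": {"title": "", "text": "Paris is the capital of France."},
          "d2": {"title": "", "text": "Lyon is a city in France."}}
queries = {"q1": "What is the capital of France?"}
results = {"q1": {"d1": 10.2, "d2": 9.1}}  # first-stage retrieval scores

rerank_results = reranker.rerank(corpus, queries, results, top_k=2)
# rerank_results maps each query id to {doc_id: log P("true")} from MonoT5.predict.
```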
@@ -1,2 +1,3 @@

```python
from .exact_search import DenseRetrievalExactSearch
from .exact_search_multi_gpu import DenseRetrievalParallelExactSearch
from .faiss_search import DenseRetrievalFaissSearch, BinaryFaissSearch, PQFaissSearch, HNSWFaissSearch, HNSWSQFaissSearch, FlatIPFaissSearch, PCAFaissSearch, SQFaissSearch
```
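To show where the new HNSWSQ export fits, a hedged retrieval sketch follows; in faiss, HNSWSQ builds the HNSW graph over scalar-quantized vectors, trading a little accuracy for a much smaller index. The checkpoint name and constructor defaults are assumptions:

```python
# Hedged sketch: approximate dense retrieval with the new HNSWSQ faiss index.
from beir.retrieval import models
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import HNSWSQFaissSearch

model = HNSWSQFaissSearch(
    models.SentenceBERT("msmarco-distilbert-base-v3"),  # checkpoint name assumed
    batch_size=128,
)
retriever = EvaluateRetrieval(model, score_function="cos_sim")
# corpus, queries and qrels as returned by a BEIR data loader.
results = retriever.retrieve(corpus, queries)
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
```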