Merge pull request #98 from beir-cellar/development
Merge all latest development changes to master.

Changes added to the latest BEIR master branch:

1. Added support for T5 reranking models (monoT5 rerankers).
2. Added a Hugging Face data loader for BEIR datasets and uploaded all datasets to the Hugging Face Hub.
3. Added multi-GPU evaluation using the SBERT multi-process (parallel) encoding code.
4. Added the HNSWSQ method to the faiss retrieval methods.
5. Added the datasets library as a dependency in setup.py.
6. Shortened README.md significantly and created a separate BEIR wiki.
thakur-nandan authored Jun 30, 2022
2 parents 16806d0 + d33182c commit 1a1e6ab
Showing 17 changed files with 983 additions and 390 deletions.
464 changes: 89 additions & 375 deletions README.md

Large diffs are not rendered by default.

118 changes: 118 additions & 0 deletions beir/datasets/data_loader_hf.py
@@ -0,0 +1,118 @@
from collections import defaultdict
from typing import Dict, Tuple
import os
import logging
from datasets import load_dataset, Value, Features

logger = logging.getLogger(__name__)


class HFDataLoader:

    def __init__(self, hf_repo: str = None, hf_repo_qrels: str = None, data_folder: str = None, prefix: str = None, corpus_file: str = "corpus.jsonl", query_file: str = "queries.jsonl",
                 qrels_folder: str = "qrels", qrels_file: str = "", streaming: bool = False, keep_in_memory: bool = False):
        self.corpus = {}
        self.queries = {}
        self.qrels = {}
        self.hf_repo = hf_repo
        if hf_repo:
            logger.warn("A huggingface repository is provided. This will override the data_folder, prefix and *_file arguments.")
            self.hf_repo_qrels = hf_repo_qrels if hf_repo_qrels else hf_repo + "-qrels"
        else:
            # data folder would contain these files:
            # (1) fiqa/corpus.jsonl  (format: jsonlines)
            # (2) fiqa/queries.jsonl (format: jsonlines)
            # (3) fiqa/qrels/test.tsv (format: tsv ("\t"))
            if prefix:
                query_file = prefix + "-" + query_file
                qrels_folder = prefix + "-" + qrels_folder

            self.corpus_file = os.path.join(data_folder, corpus_file) if data_folder else corpus_file
            self.query_file = os.path.join(data_folder, query_file) if data_folder else query_file
            self.qrels_folder = os.path.join(data_folder, qrels_folder) if data_folder else None
            self.qrels_file = qrels_file
        self.streaming = streaming
        self.keep_in_memory = keep_in_memory

    @staticmethod
    def check(fIn: str, ext: str):
        if not os.path.exists(fIn):
            raise ValueError("File {} not present! Please provide accurate file.".format(fIn))

        if not fIn.endswith(ext):
            raise ValueError("File {} must be present with extension {}".format(fIn, ext))

    def load(self, split="test") -> Tuple[Dict[str, Dict[str, str]], Dict[str, str], Dict[str, Dict[str, int]]]:

        if not self.hf_repo:
            self.qrels_file = os.path.join(self.qrels_folder, split + ".tsv")
            self.check(fIn=self.corpus_file, ext="jsonl")
            self.check(fIn=self.query_file, ext="jsonl")
            self.check(fIn=self.qrels_file, ext="tsv")

        if not len(self.corpus):
            logger.info("Loading Corpus...")
            self._load_corpus()
            logger.info("Loaded %d %s Documents.", len(self.corpus), split.upper())
            logger.info("Doc Example: %s", self.corpus[0])

        if not len(self.queries):
            logger.info("Loading Queries...")
            self._load_queries()

        self._load_qrels(split)
        # filter queries with no qrels
        qrels_dict = defaultdict(dict)

        def qrels_dict_init(row):
            qrels_dict[row['query-id']][row['corpus-id']] = int(row['score'])
        self.qrels.map(qrels_dict_init)
        self.qrels = qrels_dict
        self.queries = self.queries.filter(lambda x: x['id'] in self.qrels)
        logger.info("Loaded %d %s Queries.", len(self.queries), split.upper())
        logger.info("Query Example: %s", self.queries[0])

        return self.corpus, self.queries, self.qrels

    def load_corpus(self) -> Dict[str, Dict[str, str]]:
        if not self.hf_repo:
            self.check(fIn=self.corpus_file, ext="jsonl")

        if not len(self.corpus):
            logger.info("Loading Corpus...")
            self._load_corpus()
            logger.info("Loaded %d Documents.", len(self.corpus))
            logger.info("Doc Example: %s", self.corpus[0])

        return self.corpus

    def _load_corpus(self):
        if self.hf_repo:
            corpus_ds = load_dataset(self.hf_repo, 'corpus', keep_in_memory=self.keep_in_memory, streaming=self.streaming)
        else:
            corpus_ds = load_dataset('json', data_files=self.corpus_file, streaming=self.streaming, keep_in_memory=self.keep_in_memory)
        corpus_ds = next(iter(corpus_ds.values()))  # get first split
        corpus_ds = corpus_ds.cast_column('_id', Value('string'))
        corpus_ds = corpus_ds.rename_column('_id', 'id')
        corpus_ds = corpus_ds.remove_columns([col for col in corpus_ds.column_names if col not in ['id', 'text', 'title']])
        self.corpus = corpus_ds

    def _load_queries(self):
        if self.hf_repo:
            queries_ds = load_dataset(self.hf_repo, 'queries', keep_in_memory=self.keep_in_memory, streaming=self.streaming)
        else:
            queries_ds = load_dataset('json', data_files=self.query_file, streaming=self.streaming, keep_in_memory=self.keep_in_memory)
        queries_ds = next(iter(queries_ds.values()))  # get first split
        queries_ds = queries_ds.cast_column('_id', Value('string'))
        queries_ds = queries_ds.rename_column('_id', 'id')
        queries_ds = queries_ds.remove_columns([col for col in queries_ds.column_names if col not in ['id', 'text']])
        self.queries = queries_ds

    def _load_qrels(self, split):
        if self.hf_repo:
            qrels_ds = load_dataset(self.hf_repo_qrels, keep_in_memory=self.keep_in_memory, streaming=self.streaming)[split]
        else:
            qrels_ds = load_dataset('csv', data_files=self.qrels_file, delimiter='\t', keep_in_memory=self.keep_in_memory)
        features = Features({'query-id': Value('string'), 'corpus-id': Value('string'), 'score': Value('float')})
        qrels_ds = qrels_ds.cast(features)
        self.qrels = qrels_ds
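
Usage sketch (illustration, not part of this diff): loading a BEIR dataset straight from the Hugging Face Hub with the new loader. The repository id BeIR/scifact and its companion BeIR/scifact-qrels are assumptions based on the PR description; substitute any uploaded BEIR dataset.

from beir.datasets.data_loader_hf import HFDataLoader

# Corpus and queries come back as datasets.Dataset objects, qrels as a
# plain dict; qrels default to the companion "<hf_repo>-qrels" repository.
corpus, queries, qrels = HFDataLoader(
    hf_repo="BeIR/scifact",   # assumed Hub id
    streaming=False,
    keep_in_memory=False,
).load(split="test")

print(corpus[0])    # {'id': ..., 'title': ..., 'text': ...}
print(queries[0])   # {'id': ..., 'text': ...}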
3 changes: 2 additions & 1 deletion beir/reranking/models/__init__.py
@@ -1 +1,2 @@
-from .cross_encoder import CrossEncoder
+from .cross_encoder import CrossEncoder
+from .mono_t5 import MonoT5
162 changes: 162 additions & 0 deletions beir/reranking/models/mono_t5.py
@@ -0,0 +1,162 @@
# Majority of the code has been copied from PyGaggle MonoT5 implementation
# https://github.com/castorini/pygaggle/blob/master/pygaggle/rerank/transformer.py

from transformers import (AutoTokenizer,
                          AutoModelForSeq2SeqLM,
                          PreTrainedModel,
                          PreTrainedTokenizer,
                          T5ForConditionalGeneration)
from typing import List, Union, Tuple, Mapping, Optional
from dataclasses import dataclass
from tqdm.autonotebook import trange
import torch


TokenizerReturnType = Mapping[str, Union[torch.Tensor, List[int],
                                         List[List[int]],
                                         List[List[str]]]]

@dataclass
class QueryDocumentBatch:
    query: str
    documents: List[str]
    output: Optional[TokenizerReturnType] = None

    def __len__(self):
        return len(self.documents)

class QueryDocumentBatchTokenizer:
    def __init__(self,
                 tokenizer: PreTrainedTokenizer,
                 pattern: str = '{query} {document}',
                 **tokenizer_kwargs):
        self.tokenizer = tokenizer
        self.tokenizer_kwargs = tokenizer_kwargs
        self.pattern = pattern

    def encode(self, strings: List[str]):
        assert self.tokenizer and self.tokenizer_kwargs is not None, \
            'mixin used improperly'
        ret = self.tokenizer.batch_encode_plus(strings,
                                               **self.tokenizer_kwargs)
        ret['tokens'] = list(map(self.tokenizer.tokenize, strings))
        return ret

    def traverse_query_document(
            self, batch_input: Tuple[str, List[str]], batch_size: int):
        query, doc_texts = batch_input[0], batch_input[1]
        for batch_idx in range(0, len(doc_texts), batch_size):
            docs = doc_texts[batch_idx:batch_idx + batch_size]
            outputs = self.encode([self.pattern.format(
                query=query,
                document=doc) for doc in docs])
            yield QueryDocumentBatch(query, docs, outputs)

class T5BatchTokenizer(QueryDocumentBatchTokenizer):
    def __init__(self, *args, **kwargs):
        kwargs['pattern'] = 'Query: {query} Document: {document} Relevant:'
        if 'return_attention_mask' not in kwargs:
            kwargs['return_attention_mask'] = True
        if 'padding' not in kwargs:
            kwargs['padding'] = 'longest'
        if 'truncation' not in kwargs:
            kwargs['truncation'] = True
        if 'return_tensors' not in kwargs:
            kwargs['return_tensors'] = 'pt'
        if 'max_length' not in kwargs:
            kwargs['max_length'] = 512
        super().__init__(*args, **kwargs)


@torch.no_grad()
def greedy_decode(model: PreTrainedModel,
                  input_ids: torch.Tensor,
                  length: int,
                  attention_mask: torch.Tensor = None,
                  return_last_logits: bool = True) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    decode_ids = torch.full((input_ids.size(0), 1),
                            model.config.decoder_start_token_id,
                            dtype=torch.long).to(input_ids.device)
    encoder_outputs = model.get_encoder()(input_ids, attention_mask=attention_mask)
    next_token_logits = None
    for _ in range(length):
        model_inputs = model.prepare_inputs_for_generation(
            decode_ids,
            encoder_outputs=encoder_outputs,
            past=None,
            attention_mask=attention_mask,
            use_cache=True)
        outputs = model(**model_inputs)  # (batch_size, cur_len, vocab_size)
        next_token_logits = outputs[0][:, -1, :]  # (batch_size, vocab_size)
        decode_ids = torch.cat([decode_ids,
                                next_token_logits.max(1)[1].unsqueeze(-1)],
                               dim=-1)
    if return_last_logits:
        return decode_ids, next_token_logits
    return decode_ids


class MonoT5:
    def __init__(self,
                 model_path: str,
                 tokenizer: QueryDocumentBatchTokenizer = None,
                 use_amp = True,
                 token_false = None,
                 token_true = None):
        self.model = self.get_model(model_path)
        self.tokenizer = tokenizer or self.get_tokenizer(model_path)
        self.token_false_id, self.token_true_id = self.get_prediction_tokens(
            model_path, self.tokenizer, token_false, token_true)
        self.model_path = model_path
        self.device = next(self.model.parameters(), None).device
        self.use_amp = use_amp

    @staticmethod
    def get_model(model_path: str, *args, device: str = None, **kwargs) -> T5ForConditionalGeneration:
        device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        device = torch.device(device)
        return AutoModelForSeq2SeqLM.from_pretrained(model_path, *args, **kwargs).to(device).eval()

    @staticmethod
    def get_tokenizer(model_path: str, *args, **kwargs) -> T5BatchTokenizer:
        return T5BatchTokenizer(
            AutoTokenizer.from_pretrained(model_path, use_fast=False, *args, **kwargs)
        )

    @staticmethod
    def get_prediction_tokens(model_path: str, tokenizer, token_false, token_true):
        if (token_false and token_true):
            token_false_id = tokenizer.tokenizer.get_vocab()[token_false]
            token_true_id = tokenizer.tokenizer.get_vocab()[token_true]
            return token_false_id, token_true_id

    def predict(self, sentences: List[Tuple[str, str]], batch_size: int = 32, **kwargs) -> List[float]:

        sentence_dict, queries, scores = {}, [], []

        # T5 model requires a batch of single query and top-k documents
        for (query, doc_text) in sentences:
            if query not in sentence_dict:
                sentence_dict[query] = []
                queries.append(query)  # Preserves order of queries
            sentence_dict[query].append(doc_text)

        for start_idx in trange(0, len(queries), 1):  # Take one query at a time
            batch_input = (queries[start_idx], sentence_dict[queries[start_idx]])  # (single query, top-k docs)
            for batch in self.tokenizer.traverse_query_document(batch_input, batch_size):
                with torch.cuda.amp.autocast(enabled=self.use_amp):
                    input_ids = batch.output['input_ids'].to(self.device)
                    attn_mask = batch.output['attention_mask'].to(self.device)
                    _, batch_scores = greedy_decode(self.model,
                                                    input_ids,
                                                    length=1,
                                                    attention_mask=attn_mask,
                                                    return_last_logits=True)

                    batch_scores = batch_scores[:, [self.token_false_id, self.token_true_id]]
                    batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
                    batch_log_probs = batch_scores[:, 1].tolist()
                    scores.extend(batch_log_probs)

        assert len(scores) == len(sentences)  # Sanity check, should be equal
        return scores
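
Usage sketch (illustration, not part of this diff): plugging MonoT5 into BEIR's Rerank wrapper. The castorini/monot5-base-msmarco checkpoint and its ▁false/▁true relevance tokens are assumptions; results is the output of a first-stage retriever.

from beir.reranking import Rerank
from beir.reranking.models import MonoT5

# monoT5 fills the pattern "Query: {query} Document: {document} Relevant:"
# and scores each pair by log P("true" token) after a single decoding step.
cross_encoder = MonoT5("castorini/monot5-base-msmarco",
                       token_false='▁false', token_true='▁true')
reranker = Rerank(cross_encoder, batch_size=32)

# corpus/queries from a data loader, results e.g. from BM25 top-100
rerank_results = reranker.rerank(corpus, queries, results, top_k=100)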
48 changes: 46 additions & 2 deletions beir/retrieval/models/sentence_bert.py
@@ -1,7 +1,14 @@
from sentence_transformers import SentenceTransformer
from torch import Tensor
import torch.multiprocessing as mp
from typing import List, Dict, Union, Tuple
import numpy as np
import logging
from datasets import Dataset
from tqdm import tqdm

logger = logging.getLogger(__name__)


class SentenceBERT:
    def __init__(self, model_path: Union[str, Tuple] = None, sep: str = " ", **kwargs):
@@ -15,9 +22,46 @@ def __init__(self, model_path: Union[str, Tuple] = None, sep: str = " ", **kwargs):
            self.q_model = SentenceTransformer(model_path[0])
            self.doc_model = SentenceTransformer(model_path[1])

    def start_multi_process_pool(self, target_devices: List[str] = None) -> Dict[str, object]:
        logger.info("Start multi-process pool on devices: {}".format(', '.join(map(str, target_devices))))

        ctx = mp.get_context('spawn')
        input_queue = ctx.Queue()
        output_queue = ctx.Queue()
        processes = []

        for process_id, device_name in enumerate(target_devices):
            p = ctx.Process(target=SentenceTransformer._encode_multi_process_worker, args=(process_id, device_name, self.doc_model, input_queue, output_queue), daemon=True)
            p.start()
            processes.append(p)

        return {'input': input_queue, 'output': output_queue, 'processes': processes}

    def stop_multi_process_pool(self, pool: Dict[str, object]):
        output_queue = pool['output']
        [output_queue.get() for _ in range(len(pool['processes']))]
        return self.doc_model.stop_multi_process_pool(pool)

    def encode_queries(self, queries: List[str], batch_size: int = 16, **kwargs) -> Union[List[Tensor], np.ndarray, Tensor]:
        return self.q_model.encode(queries, batch_size=batch_size, **kwargs)

-    def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int = 8, **kwargs) -> Union[List[Tensor], np.ndarray, Tensor]:
-        sentences = [(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip() for doc in corpus]
+    def encode_corpus(self, corpus: Union[List[Dict[str, str]], Dict[str, List]], batch_size: int = 8, **kwargs) -> Union[List[Tensor], np.ndarray, Tensor]:
+        if type(corpus) is dict:
+            sentences = [(corpus["title"][i] + self.sep + corpus["text"][i]).strip() if "title" in corpus else corpus["text"][i].strip() for i in range(len(corpus['text']))]
+        else:
+            sentences = [(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip() for doc in corpus]
        return self.doc_model.encode(sentences, batch_size=batch_size, **kwargs)

    ## Encoding corpus in parallel
    def encode_corpus_parallel(self, corpus: Union[List[Dict[str, str]], Dataset], pool: Dict[str, str], batch_size: int = 8, chunk_id: int = None, **kwargs):
        if type(corpus) is dict:
            sentences = [(corpus["title"][i] + self.sep + corpus["text"][i]).strip() if "title" in corpus else corpus["text"][i].strip() for i in range(len(corpus['text']))]
        else:
            sentences = [(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip() for doc in corpus]

        if chunk_id is not None and chunk_id >= len(pool['processes']):
            output_queue = pool['output']
            output_queue.get()

        input_queue = pool['input']
        input_queue.put([chunk_id, batch_size, sentences])
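
Usage sketch (illustration, not part of this diff): the pool methods above mirror SentenceTransformers' multi-process encoding API and are driven by the new DenseRetrievalParallelExactSearch class (its diff is not shown above). The constructor arguments here are assumptions.

from beir.retrieval import models
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalParallelExactSearch as DRPES

# Shard corpus encoding across several GPUs via the spawn-based process pool.
beir_model = models.SentenceBERT("msmarco-distilbert-base-tas-b")
model = DRPES(beir_model, batch_size=128, target_devices=["cuda:0", "cuda:1"])
retriever = EvaluateRetrieval(model, score_function="dot")
results = retriever.retrieve(corpus, queries)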
5 changes: 3 additions & 2 deletions beir/retrieval/search/dense/__init__.py
@@ -1,2 +1,3 @@
-from .exact_search import DenseRetrievalExactSearch
-from .faiss_search import DenseRetrievalFaissSearch, BinaryFaissSearch, PQFaissSearch, HNSWFaissSearch, FlatIPFaissSearch, PCAFaissSearch, SQFaissSearch
+from .exact_search import DenseRetrievalExactSearch
+from .exact_search_multi_gpu import DenseRetrievalParallelExactSearch
+from .faiss_search import DenseRetrievalFaissSearch, BinaryFaissSearch, PQFaissSearch, HNSWFaissSearch, HNSWSQFaissSearch, FlatIPFaissSearch, PCAFaissSearch, SQFaissSearch
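
Usage sketch (illustration, not part of this diff): building the newly exported HNSWSQ index. The hyperparameter names mirror the existing HNSWFaissSearch options and are assumptions here.

from beir.retrieval import models
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import HNSWSQFaissSearch

# HNSW graph over scalar-quantized vectors: a smaller faiss index than
# plain HNSW, traded against a small loss in recall.
model = models.SentenceBERT("msmarco-distilbert-base-tas-b")
faiss_search = HNSWSQFaissSearch(model, batch_size=128, hnsw_store_n=512,
                                 hnsw_ef_search=512, hnsw_ef_construction=512)
retriever = EvaluateRetrieval(faiss_search, score_function="dot")
results = retriever.retrieve(corpus, queries)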
2 changes: 1 addition & 1 deletion beir/retrieval/search/dense/exact_search.py
@@ -66,7 +66,7 @@ def search(self,
             cos_scores[torch.isnan(cos_scores)] = -1
 
             # Get top-k values
-            cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk(cos_scores, min(top_k+1, len(cos_scores[0])), dim=1, largest=True, sorted=return_sorted)
+            cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk(cos_scores, min(top_k+1, len(cos_scores[1])), dim=1, largest=True, sorted=return_sorted)
             cos_scores_top_k_values = cos_scores_top_k_values.cpu().tolist()
             cos_scores_top_k_idx = cos_scores_top_k_idx.cpu().tolist()
