
Commit

(almost) all the methods now accept the argument 'sep'; 'mnrl.mnrl' and 'evaluation.evaluate' both load SBERT via 'sbert.load_sbert'; added support for 'k_values' in 'evaluation.evaluate'; added README & example for 'evaluation.evaluate'
kwang2049 committed Dec 19, 2021
1 parent 631999a commit 8724d2d
Showing 8 changed files with 89 additions and 31 deletions.
44 changes: 44 additions & 0 deletions gpl/toolkit/README.md
@@ -22,3 +22,47 @@ What this [`gpl.toolkit.reformat.dpr_lik`](https://github.com/UKPLab/gpl/blob/72
2. Then within the `word_embedding_model`, it traces along the path `DPRQuestionEncoder` -> `DPREncoder` -> `BertModel` -> `BertPooler` to get the `BertModel` and the `BertPooler` (the final linear layer of DPR models);
3. Compose everything (including the `BertModel`, a CLS pooling layer and the `BertPooler`) together into an SBERT-format checkpoint.
4. Save the reformatted checkpoint into the `--output_path`.
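
Putting steps 2-4 together, here is a minimal sketch of the re-composition with `sentence_transformers`. The names `bert_name_or_path`, `output_path` and the freshly initialized `pooler_like` Dense module are stand-ins for illustration; the actual `dpr_like` routine plugs in the `BertModel` and `BertPooler` weights extracted from the DPR checkpoint in step 2:
```python
import torch
from sentence_transformers import SentenceTransformer, models

bert_name_or_path = "bert-base-uncased"    # stand-in for the BertModel extracted from the DPR checkpoint
output_path = "dpr-like-sbert-checkpoint"  # where the reformatted SBERT checkpoint will be written

word_embedding_model = models.Transformer(bert_name_or_path, max_seq_length=350)
hidden_dim = word_embedding_model.get_word_embedding_dimension()

# DPR represents a text by its [CLS] token, hence a CLS pooling layer
cls_pooling = models.Pooling(hidden_dim, pooling_mode="cls")

# A Dense module (Linear + Tanh) playing the role of the BertPooler, the final linear layer of DPR models;
# the real routine copies the extracted BertPooler weights into this module
pooler_like = models.Dense(in_features=hidden_dim, out_features=hidden_dim, activation_function=torch.nn.Tanh())

# Compose everything into an SBERT-format checkpoint and save it
model = SentenceTransformer(modules=[word_embedding_model, cls_pooling, pooler_like])
model.save(output_path)
```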

## evaluation
We can evaluate a checkpoint either (1) within the `gpl.train` workflow or (2) in an independent routine. For case (2), we can simply run the command below, using SciFact as the example dataset (writing a small Python script with `from gpl.toolkit import evaluate; evaluate(...)` is also supported; see the sketch after the example output below):
```bash
export dataset="scifact"
if [ ! -d "$dataset" ]; then
wget https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/$dataset.zip
unzip $dataset.zip
fi

python -m gpl.toolkit.evaluation \
--data_path $dataset \
--output_dir $dataset \
--model_name_or_path "GPL/msmarco-distilbert-margin-mse" \
--max_seq_length 350 \
--score_function "dot" \
--sep " " \
--k_values 10 100
```
This will save the results in `--output_dir`:
```bash
# cat scifact/results.json | json_pp
{
"map" : {
"MAP@10" : 0.53105,
"MAP@100" : 0.53899
},
"mrr" : {
"MRR@10" : 0.54623
},
"ndcg" : {
"NDCG@10" : 0.57078,
"NDCG@100" : 0.60891
},
"precicion" : {
"P@10" : 0.07667,
"P@100" : 0.0097
},
"recall" : {
"Recall@10" : 0.6765,
"Recall@100" : 0.85533
}
}
```
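
Equivalently, the same evaluation can be run from a Python script, as mentioned above. A minimal sketch mirroring the CLI call, with argument names following the `evaluate` signature in `gpl/toolkit/evaluation.py`:
```python
from gpl.toolkit import evaluate

evaluate(
    data_path="scifact",
    output_dir="scifact",
    model_name_or_path="GPL/msmarco-distilbert-margin-mse",
    max_seq_length=350,
    score_function="dot",   # "dot" or "cos_sim"
    pooling=None,
    sep=" ",                # title and body are concatenated as sep.join([title, body])
    k_values=[10, 100],     # report nDCG@K, MAP@K, recall@K and precision@K
)
```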
34 changes: 20 additions & 14 deletions gpl/toolkit/evaluation.py
@@ -7,11 +7,22 @@
import logging
import numpy as np
import json
from typing import List
import argparse
from .sbert import load_sbert
logger = logging.getLogger(__name__)


def evaluate(data_path, output_dir, model_name_or_path, max_seq_length, score_function, pooling):
def evaluate(
data_path: str,
output_dir: str,
model_name_or_path: str,
max_seq_length: int = 350,
score_function: str = 'dot',
pooling: str = None,
sep: str = ' ',
k_values: List[int] = [10,]
):
data_paths = []
if 'cqadupstack' in data_path:
data_paths = [os.path.join(data_path, sub_dataset) for sub_dataset in \
@@ -27,15 +38,8 @@ def evaluate(data_path, output_dir, model_name_or_path, max_seq_length, score_fu
for data_path in data_paths:
corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")

if '0_Transformer' in os.listdir(model_name_or_path):
model_name_or_path = os.path.join(model_name_or_path, '0_Transformer')
word_embedding_model = sentence_transformers.models.Transformer(model_name_or_path, max_seq_length=max_seq_length)
pooling_model = sentence_transformers.models.Pooling(
word_embedding_model.get_word_embedding_dimension(),
pooling
)
model = sentence_transformers.SentenceTransformer(modules=[word_embedding_model, pooling_model])
sbert = models.SentenceBERT(sep=' ')
model = load_sbert(model_name_or_path, pooling, max_seq_length)
sbert = models.SentenceBERT(sep=sep)
sbert.q_model = model
sbert.doc_model = model

@@ -45,7 +49,7 @@ def evaluate(data_path, output_dir, model_name_or_path, max_seq_length, score_fu
results = retriever.retrieve(corpus, queries)

#### Evaluate your retrieval using NDCG@k, MAP@K ...
ndcg, _map, recall, precision = EvaluateRetrieval.evaluate(qrels, results, [10,])
ndcg, _map, recall, precision = EvaluateRetrieval.evaluate(qrels, results, k_values)
mrr = EvaluateRetrieval.evaluate_custom(qrels, results, [10,], metric='mrr')
ndcgs.append(ndcg)
_maps.append(_map)
@@ -71,13 +75,15 @@ def evaluate(data_path, output_dir, model_name_or_path, max_seq_length, score_fu
}, f, indent=4)
logger.info(f'Saved evaluation results to {result_path}')

if __name__ == '__name__':
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--data_path')
parser.add_argument('--output_dir')
parser.add_argument('--model_name_or_path')
parser.add_argument('--max_seq_length', type=int)
parser.add_argument('--score_function', choices=['dot', 'cos_sim'])
parser.add_argument('--pooling', choices=['dot', 'cos_sim'])
parser.add_argument('--sep', type=str, default=' ', help="Separation token between title and body text for each passage. The concatenation way is `sep.join([title, body])`")
parser.add_argument('--k_values', nargs='+', type=int, default=[10,], help="The K values in the evaluation. This will compute nDCG@K, recall@K, precision@K and MAP@K")
args = parser.parse_args()

evaluate(args.data_path, args.output_dir, args.model_name_or_path, args.max_seq_length, args.score_function)
evaluate(**vars(args))
5 changes: 3 additions & 2 deletions gpl/toolkit/mine.py
@@ -14,9 +14,10 @@

class NegativeMiner(object):

def __init__(self, generated_path, prefix, retrievers=['bm25', 'msmarco-distilbert-base-v3', 'msmarco-MiniLM-L-6-v3'], nneg=50):
def __init__(self, generated_path, prefix, retrievers=['bm25', 'msmarco-distilbert-base-v3', 'msmarco-MiniLM-L-6-v3'], nneg=50, sep=' '):
self.corpus, self.gen_queries, self.gen_qrels = GenericDataLoader(generated_path, prefix=prefix).load(split="train")
self.output_path = os.path.join(generated_path, 'hard-negatives.jsonl')
self.sep = sep
self.retrievers = retrievers
if 'bm25' in retrievers:
assert nneg <= 10000, "Only `negatives_per_query` <= 10000 is acceptable by Elasticsearch-BM25"
@@ -27,7 +28,7 @@ def __init__(self, generated_path, prefix, retrievers=['bm25', 'msmarco-distilbe
self.nneg = len(self.corpus)

def _get_doc(self, did):
return ' '.join([self.corpus[did]['title'], self.corpus[did]['text']])
return self.sep.join([self.corpus[did]['title'], self.corpus[did]['text']])

def _mine_sbert(self, model_name):
logger.info(f'Mining with {model_name}')
13 changes: 6 additions & 7 deletions gpl/toolkit/mnrl.py
@@ -1,24 +1,23 @@
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.train import TrainRetriever
from sentence_transformers import SentenceTransformer, losses, models
from .sbert import load_sbert
import os


def mnrl(data_path, base_ckpt, output_dir, max_seq_length=350, use_amp=True):
prefix = "gen"
# TODO: `sep` argument
def mnrl(data_path, base_ckpt, output_dir, max_seq_length=350, use_amp=True, qgen_prefix='qgen', pooling=None):
#### Training on Generated Queries ####
corpus, gen_queries, gen_qrels = GenericDataLoader(data_path, prefix=prefix).load(split="train")
corpus, gen_queries, gen_qrels = GenericDataLoader(data_path, prefix=qgen_prefix).load(split="train")

#### Provide any HuggingFace model and fine-tune from scratch
model_name = base_ckpt
word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
model = load_sbert(base_ckpt, pooling, max_seq_length)

#### Provide any sentence-transformers model path
retriever = TrainRetriever(model=model, batch_size=75)

#### Prepare training samples
#### The current version of QGen training in BeIR use fixed separation token ' ' !!!
train_samples = retriever.load_train(corpus, gen_queries, gen_qrels)
train_dataloader = retriever.prepare_train(train_samples, shuffle=True)
train_loss = losses.MultipleNegativesRankingLoss(model=retriever.model)
4 changes: 2 additions & 2 deletions gpl/toolkit/pl.py
@@ -15,11 +15,11 @@ def hard_negative_collate_fn(batch):

class PseudoLabeler(object):

def __init__(self, generated_path, gen_queries, corpus, total_steps, batch_size, cross_encoder):
def __init__(self, generated_path, gen_queries, corpus, total_steps, batch_size, cross_encoder, sep=' '):
assert 'hard-negatives.jsonl' in os.listdir(generated_path)
fpath_hard_negatives = os.path.join(generated_path, 'hard-negatives.jsonl')
self.cross_encoder = CrossEncoder(cross_encoder)
hard_negative_dataset = HardNegativeDataset(fpath_hard_negatives, gen_queries, corpus)
hard_negative_dataset = HardNegativeDataset(fpath_hard_negatives, gen_queries, corpus, sep)
self.hard_negative_dataloader = DataLoader(hard_negative_dataset, shuffle=True, batch_size=batch_size, drop_last=True)
self.hard_negative_dataloader.collate_fn = hard_negative_collate_fn
self.output_path = os.path.join(generated_path, 'gpl-training-data.tsv')
2 changes: 2 additions & 0 deletions gpl/toolkit/qgen.py
@@ -5,11 +5,13 @@
import argparse


# TODO: `sep` argument
def qgen(data_path, output_dir, generator_name_or_path='BeIR/query-gen-msmarco-t5-base-v1', ques_per_passage=3, bsz=32, qgen_prefix='qgen'):
#### Provide the data_path where nfcorpus has been downloaded and unzipped
corpus = GenericDataLoader(data_path).load_corpus()

#### question-generation model loading
#### The current version of QGen in BeIR use fixed separation token ' ' !!!
generator = QGen(model=QGenModel(generator_name_or_path))

#### Query-Generation using Nucleus Sampling (top_k=25, top_p=0.95) ####
16 changes: 11 additions & 5 deletions gpl/train.py
@@ -32,7 +32,9 @@ def train(
gpl_steps: int = 140000,
use_amp: bool = False,
retrievers: List[str] = ['msmarco-distilbert-base-v3', 'msmarco-MiniLM-L-6-v3'],
negatives_per_query: int = 50
negatives_per_query: int = 50,
sep:str = ' ',
k_values: List[int] = [10,]
):
#### Assertions ####
assert pooling in [None, 'mean', 'cls', 'max']
@@ -79,7 +81,7 @@
logger.info('Using exisiting hard-negative data')
else:
logger.info('No hard-negative data found. Now mining it')
miner = NegativeMiner(path_to_generated_data, qgen_prefix, retrievers=retrievers, nneg=negatives_per_query)
miner = NegativeMiner(path_to_generated_data, qgen_prefix, retrievers=retrievers, nneg=negatives_per_query, sep=sep)
miner.run()


@@ -89,7 +91,7 @@
logger.info('Using existing GPL-training data')
else:
logger.info('No GPL-training data found. Now generating it via pseudo labeling')
pseudo_labeler = PseudoLabeler(path_to_generated_data, gen_queries, corpus, gpl_steps, batch_size_gpl, cross_encoder)
pseudo_labeler = PseudoLabeler(path_to_generated_data, gen_queries, corpus, gpl_steps, batch_size_gpl, cross_encoder, sep)
pseudo_labeler.run()


@@ -102,7 +104,7 @@
model: SentenceTransformer = load_sbert(base_ckpt, pooling, max_seq_length)

fpath_gpl_data = os.path.join(path_to_generated_data, 'gpl-training-data.tsv')
train_dataset = GenerativePseudoLabelingDataset(fpath_gpl_data, gen_queries, corpus)
train_dataset = GenerativePseudoLabelingDataset(fpath_gpl_data, gen_queries, corpus, sep)
train_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=batch_size_gpl, drop_last=True) # Here shuffle=False, since (or assuming) we have done it in the pseudo labeling
train_loss = MarginDistillationLoss(model=model)

@@ -131,7 +133,9 @@
ckpt_dir,
max_seq_length,
score_function='dot', # Since for now MarginMSE only works with dot-product
pooling=pooling
pooling=pooling,
sep=sep,
k_values=k_values
)


@@ -156,5 +160,7 @@
parser.add_argument('--use_amp', action='store_true', default=False, help='Whether to use half precision')
parser.add_argument('--retrievers', nargs='+', default=['msmarco-distilbert-base-v3', 'msmarco-MiniLM-L-6-v3'], help='Indicate retriever names. They could be one or many BM25 ("bm25") or dense retrievers (in SBERT format).')
parser.add_argument('--negatives_per_query', type=int, default=50, help="Mine how many negatives per query per retriever")
parser.add_argument('--sep', type=str, default=' ', help="Separation token between title and body text for each passage. The concatenation way is `sep.join([title, body])`")
parser.add_argument('--k_values', nargs='+', type=int, default=[10,], help="The K values in the evaluation. This will compute nDCG@K, recall@K, precision@K and MAP@K")
args = parser.parse_args()
train(**vars(args))
2 changes: 1 addition & 1 deletion setup.py
@@ -6,7 +6,7 @@

setup(
name="gpl",
version="0.0.7",
version="0.0.8",
author="Kexin Wang",
author_email="kexin.wang.2049@gmail.com",
description="GPL is an unsupervised domain adaptation method for training dense retrievers. It is based on query generation and pseudo labeling with powerful cross-encoders. To train a domain-adapted model, it needs only the unlabeled target corpus and can achieve significant improvement over zero-shot models.",
