This repository provides miscellaneous utilities for ranking models.
Clone this repository and run:
python -m pip install .
The dataset pre-processing script reads a dataset and creates training, dev and test sets (HDF5 format) that can be used by the ranking models. Run the script as follows to see available options:
python -m ranking_utils.scripts.create_h5_data
The following datasets are currently supported:
- ANTIQUE
- FiQA Task 2
- InsuranceQA (v2)
- MS MARCO (v1) passage and document ranking (TREC-DL test sets)
- Any dataset in generic TREC format
The script uses Hydra. Refer to the Hydra documentation for detailed instructions on how to configure arguments.
The following pre-processes the ANTIQUE dataset:
python -m ranking_utils.scripts.create_h5_data \
    dataset=antique \
    dataset.root_dir=/path/to/antique/files \
    hydra.run.dir=/path/to/output/files
To see all available options for a dataset, run:
python -m ranking_utils.scripts.create_h5_data \
    dataset=antique \
    --help
Implementing a ranker requires two components:
- A `DataProcessor` (specific to your model) that subclasses `ranking_utils.model.data.DataProcessor` and implements the following methods:
  - `get_model_input(self, query: str, doc: str) -> ModelInput`: transforms a query-document pair into an input that is suitable for the model.
  - `get_model_batch(self, inputs: Iterable[ModelInput]) -> ModelBatch`: creates a model batch from multiple inputs.
- The ranking model itself, which subclasses `ranking_utils.model.Ranker` and implements the following methods:
  - `forward(self, batch: ModelBatch) -> torch.Tensor`: computes query-document scores; output shape `(batch_size, 1)`.
  - `configure_optimizers(self) -> Tuple[List[Any], List[Any]]`: configures optimizers (and schedulers). Refer to the PyTorch Lightning documentation.
You can then train your model using the usual PyTorch Lightning setup. For example:
from pathlib import Path
from pytorch_lightning import Trainer
from ranking_utils.model.data import H5DataModule
from ranking_utils.model import TrainingMode
from my_ranker import MyRanker, MyDataProcessor
data_module = H5DataModule(
    data_processor=MyDataProcessor(...),
    data_dir=Path(...),
    fold_name="fold_0",
    batch_size=32,
)
model = MyRanker(...)
data_module.training_mode = model.training_mode = TrainingMode.PAIRWISE
model.margin = 0.2
Trainer(...).fit(model=model, datamodule=data_module)
The following training modes are supported:
- `TrainingMode.POINTWISE`
- `TrainingMode.PAIRWISE`
- `TrainingMode.CONTRASTIVE`
After each epoch, the ranker automatically computes the following ranking metrics on the validation set:
- `val_RetrievalMAP`
- `val_RetrievalMRR`
- `val_RetrievalNormalizedDCG`
These can be used in combination with callbacks, e.g. early stopping.
As with validation, the following metrics can be computed on the test set (using PyTorch Lightning's testing functionality):
- `test_RetrievalMAP`
- `test_RetrievalMRR`
- `test_RetrievalNormalizedDCG`
Example implementations of various models using this library can be found here.