ranking-utils

This repository provides miscellaneous utilities for ranking models.

Installation

Clone this repository and run:

python -m pip install .

Usage

Dataset Pre-Processing

The dataset pre-processing script reads a dataset and creates training, dev, and test sets (in HDF5 format) that can be used by the ranking models. Run the script as follows to see the available options:

python -m ranking_utils.scripts.create_h5_data

A number of datasets are supported out of the box; ANTIQUE is used in the example below.

The script uses Hydra. Refer to the Hydra documentation for detailed instructions on how to configure arguments.

Example

The following command pre-processes the ANTIQUE dataset:

python -m ranking_utils.scripts.create_h5_data \
    dataset=antique \
    dataset.root_dir=/path/to/antique/files \
    hydra.run.dir=/path/to/output/files

To see all available options for a dataset, run:

python -m ranking_utils.scripts.create_h5_data \
    dataset=antique \
    --help

Ranking

Implementing a ranker requires two components (a minimal sketch of both follows the list):

  1. A data processor specific to your model, which subclasses ranking_utils.model.data.DataProcessor and implements the following methods:
    • get_model_input(self, query: str, doc: str) -> ModelInput: Transforms a query-document pair into an input that is suitable for the model.
    • get_model_batch(self, inputs: Iterable[ModelInput]) -> ModelBatch: Creates a model batch from multiple inputs.
  2. The ranking model itself, which subclasses ranking_utils.model.Ranker and implements the following methods:
    • forward(self, batch: ModelBatch) -> torch.Tensor: Computes query-document scores; output shape (batch_size, 1).
    • configure_optimizers(self) -> Tuple[List[Any], List[Any]]: Configures optimizers (and, optionally, schedulers). Refer to the PyTorch Lightning documentation.
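
The sketch below assumes that a model input is simply a (query, document) string pair and that a batch is a list of such pairs; the class names, the placeholder scoring layer, the AdamW optimizer, and the learning rate are illustrative choices, and the library's actual ModelInput, ModelBatch, and Ranker constructor may differ:

from typing import Any, Iterable, List, Tuple

import torch
from ranking_utils.model import Ranker
from ranking_utils.model.data import DataProcessor

# Illustrative type aliases for this sketch only.
ModelInput = Tuple[str, str]
ModelBatch = List[ModelInput]


class MyDataProcessor(DataProcessor):
    def get_model_input(self, query: str, doc: str) -> ModelInput:
        # Transform a single query-document pair into a model input.
        return query, doc

    def get_model_batch(self, inputs: Iterable[ModelInput]) -> ModelBatch:
        # Collate multiple inputs into one batch.
        return list(inputs)


class MyRanker(Ranker):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # The constructor arguments of Ranker are not shown here; pass them through.
        super().__init__(*args, **kwargs)
        self.score = torch.nn.Linear(2, 1)  # placeholder scoring layer

    def forward(self, batch: ModelBatch) -> torch.Tensor:
        # One relevance score per query-document pair, output shape (batch_size, 1).
        placeholder_features = torch.zeros(len(batch), 2)
        return self.score(placeholder_features)

    def configure_optimizers(self) -> Tuple[List[Any], List[Any]]:
        # Return optimizers and (optionally) schedulers as expected by PyTorch Lightning.
        return [torch.optim.AdamW(self.parameters(), lr=1e-3)], []

In practice, get_model_input would typically tokenize the query and document, and get_model_batch would pad and stack the results into tensors.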

You can then train your model using the usual PyTorch Lightning setup. For example:

from pathlib import Path
from pytorch_lightning import Trainer
from ranking_utils.model.data import H5DataModule
from ranking_utils.model import TrainingMode
from my_ranker import MyRanker, MyDataProcessor

data_module = H5DataModule(
    data_processor=MyDataProcessor(...),
    data_dir=Path(...),
    fold_name="fold_0",
    batch_size=32
)
model = MyRanker(...)
data_module.training_mode = model.training_mode = TrainingMode.PAIRWISE
model.margin = 0.2
Trainer(...).fit(model=model, datamodule=data_module)

The following training modes are supported:

  • TrainingMode.POINTWISE
  • TrainingMode.PAIRWISE
  • TrainingMode.CONTRASTIVE
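
Any of these can be selected in the same way as in the example above, e.g.:

data_module.training_mode = model.training_mode = TrainingMode.POINTWISE

The margin attribute set in the example above applies to the pairwise case; which attributes the other modes require is not documented here.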

Validation

After each epoch, the ranker automatically computes the following ranking metrics on the validation set:

  • val_RetrievalMAP
  • val_RetrievalMRR
  • val_RetrievalNormalizedDCG

These can be used in combination with callbacks, e.g. early stopping.
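
For example, PyTorch Lightning callbacks can monitor one of these metrics; the patience value below is an arbitrary choice:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# Stop training once validation MAP stops improving and keep the best checkpoint.
callbacks = [
    EarlyStopping(monitor="val_RetrievalMAP", mode="max", patience=3),
    ModelCheckpoint(monitor="val_RetrievalMAP", mode="max"),
]
Trainer(callbacks=callbacks).fit(model=model, datamodule=data_module)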

Testing

As with validation, the following metrics can be computed on the test set using PyTorch Lightning's testing functionality (see the example below the list):

  • test_RetrievalMAP
  • test_RetrievalMRR
  • test_RetrievalNormalizedDCG
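
Using the same data module and a Trainer instance kept in a variable (here called trainer), the test metrics are computed and logged with a single call:

# Evaluate the trained model on the test set; the test_* metrics are logged.
trainer.test(model=model, datamodule=data_module)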

Examples

Example implementations of various models using this library can be found here.
