Python library for the differentiable hypergeometric distribution.
This is the official code for the ICLR 2023 paper (Spotlight) "Learning Group Importance using the Differentiable Hypergeometric Distribution".
We are still working on the code and the repository. Feedback and requests are very welcome.
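The library centers on the multivariate hypergeometric distribution with group importance weights. To illustrate what is being modelled, the sketch below evaluates the probability mass function by brute-force enumeration. It assumes the multivariate Fisher noncentral hypergeometric variant: drawing n items from groups of sizes m_i with importance weights w_i gives P(x) proportional to prod_i comb(m_i, x_i) * w_i ** x_i, over all count vectors x with sum(x) == n. It is not the library's API and does not include the differentiable relaxation. The names support and fisher_mvhg_pmf are hypothetical.

```python
from itertools import product
from math import comb, prod


def support(m, n):
    """Enumerate all count vectors x with 0 <= x_i <= m_i and sum(x) == n."""
    ranges = [range(mi + 1) for mi in m]
    return [x for x in product(*ranges) if sum(x) == n]


def fisher_mvhg_pmf(m, n, w):
    """Return {x: P(x)} for a multivariate Fisher noncentral hypergeometric
    distribution (assumed variant) with group sizes m, draw size n and
    importance weights w.

    P(x) is proportional to prod_i comb(m_i, x_i) * w_i ** x_i; the
    normalizer is obtained by summing over the whole (small) support.
    """
    unnorm = {
        x: prod(comb(mi, xi) * wi ** xi for mi, xi, wi in zip(m, x, w))
        for x in support(m, n)
    }
    z = sum(unnorm.values())
    return {x: p / z for x, p in unnorm.items()}
```

With all weights equal to 1 this reduces to the central hypergeometric distribution; for example, fisher_mvhg_pmf((2, 3), 2, (1.0, 1.0))[(1, 1)] is 0.6. Raising a group's weight shifts probability mass toward drawing more items from that group. The support grows combinatorially with the number of groups, so enumeration is only suitable for toy cases.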
We provide an environment file, env_mvhg.yml, that helps you set up a conda environment.
For help installing conda, please follow the guidelines on the official conda website.
To create and activate the required conda environment and install the package, run the following commands:
conda env create -f env_mvhg.yml
conda activate mvhg
pip install ".[pt]"
The conda environment uses Python 3.8.
We provide a minimal example that learns class weights from samples. It uses PyTorch Lightning, Weights & Biases, and Hydra for configuration.
To run it, execute the following command in the root directory:
python main_minimal_app.py
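The minimal example learns class weights with the differentiable estimator, PyTorch Lightning and Hydra; its code is not reproduced here. As a dependency-free illustration of the same estimation problem, the sketch below swaps in a different technique: exact maximum-likelihood fitting by enumerating the support. It again assumes the Fisher noncentral variant, parameterized by log-weights theta_i = log w_i. Under that assumption the distribution is an exponential family with sufficient statistic x, so the gradient of the average log-likelihood is mean(x) - E_theta[x]. The functions support, expected_counts and fit_log_weights are hypothetical helpers, not part of the repository.

```python
from itertools import product
from math import comb, exp, log


def support(m, n):
    """All count vectors x with 0 <= x_i <= m_i and sum(x) == n."""
    return [x for x in product(*(range(mi + 1) for mi in m)) if sum(x) == n]


def expected_counts(m, n, theta):
    """E[x] under the (assumed) Fisher noncentral MVHG with log-weights theta."""
    xs = support(m, n)
    # Unnormalized log-probabilities: sum_i log comb(m_i, x_i) + theta_i * x_i.
    logp = [
        sum(log(comb(mi, xi)) + ti * xi for mi, xi, ti in zip(m, x, theta))
        for x in xs
    ]
    c = max(logp)  # subtract the max before exponentiating, for stability
    ws = [exp(lp - c) for lp in logp]
    z = sum(ws)
    return [sum(w * x[i] for w, x in zip(ws, xs)) / z for i in range(len(m))]


def fit_log_weights(samples, m, n, lr=0.1, steps=2000):
    """Plain gradient ascent on the average log-likelihood of the samples.

    theta_0 is pinned to 0: since sum(x) == n always, adding the same
    constant to every theta_i leaves the distribution unchanged, so the
    weights are only identifiable relative to one reference group.
    """
    k = len(m)
    mean_x = [sum(s[i] for s in samples) / len(samples) for i in range(k)]
    theta = [0.0] * k
    for _ in range(steps):
        ex = expected_counts(m, n, theta)
        for i in range(1, k):
            theta[i] += lr * (mean_x[i] - ex[i])
    return theta
```

At the optimum the fitted distribution matches the empirical mean counts, which is the usual moment-matching property of exponential families. For instance, fitting fit_log_weights([(2, 1), (1, 2), (2, 1), (3, 0)], (3, 3), 3) gives a negative theta[1], i.e. the second group is judged less important than the first. This exact approach becomes intractable as the support grows, which is one motivation for a differentiable, sampling-based estimator.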
If you use our model in your own work, please cite us using the following BibTeX entry:
@inproceedings{sutter2023,
title={Learning Group Importance using the Differentiable Hypergeometric Distribution},
author={Sutter, Thomas M and Manduchi, Laura and Ryser, Alain and Vogt, Julia E},
year={2023},
booktitle={International Conference on Learning Representations},
}
For any questions or requests, please reach out to: Thomas Sutter (thomas.sutter@inf.ethz.ch)