d2sample is a collection of differentiable samplers for discrete objects in PyTorch, with associated examples and layers (TBD). To install, first clone the repository recursively so that its submodules are also fetched:
```bash
git clone --recursive git@github.com:sscardapane/d2sample.git
```
Then (for now) add the cloned folder to the Python path:
```python
import sys
sys.path.append('./d2sample/')
```
Requirements:
- PyTorch == 1.13.1
- functorch
$k$-subset sampling (see the notebook):
- Gumbel-Softmax with continuous top-$k$ relaxation (Xie & Ermon, 2019). For $k=1$ this reduces to the standard Gumbel-Softmax reparameterization, already available in PyTorch as `torch.nn.functional.gumbel_softmax`.
- Top-$k$ selection with I-MLE (Niepert, Minervini, & Franceschi, 2021).
- SIMPLE: Subset Implicit Likelihood Estimation (Ahmed, Zeng, Niepert, & Van den Broeck, 2022).
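The continuous top-$k$ relaxation can be sketched in a few lines. This is a minimal, hypothetical implementation for illustration only (the function `relaxed_topk` and its arguments are not the d2sample API). It adds Gumbel noise to the logits, then runs $k$ rounds of a tempered softmax. Each round softly masks the mass selected in the previous round by adding $\log(1 - a)$, and the $k$ soft one-hot vectors are summed.

```python
import torch
import torch.nn.functional as F


def relaxed_topk(logits: torch.Tensor, k: int, tau: float = 1.0, eps: float = 1e-10) -> torch.Tensor:
    """Hypothetical sketch of the continuous top-k relaxation (Xie & Ermon, 2019).

    Returns a soft k-hot tensor with the same shape as `logits`; each slice
    along the last dimension sums to k.
    """
    # Perturb the logits with i.i.d. Gumbel(0, 1) noise: g = -log(-log(u)).
    u = torch.rand_like(logits).clamp_min(eps)
    scores = logits - torch.log(-torch.log(u))
    khot = torch.zeros_like(logits)
    onehot = torch.zeros_like(logits)
    for _ in range(k):
        # Softly remove the mass picked in the previous round (no-op in round one).
        scores = scores + torch.log((1.0 - onehot).clamp_min(eps))
        onehot = F.softmax(scores / tau, dim=-1)
        khot = khot + onehot
    return khot


torch.manual_seed(0)
logits = torch.randn(4, 10, requires_grad=True)
sample = relaxed_topk(logits, k=3, tau=0.5)
sample[:, 0].sum().backward()  # gradients flow back to the logits
```

With `k=1` the loop runs once and returns `softmax((logits + g) / tau)`, which is the same soft sample that `F.gumbel_softmax(logits, tau=tau)` draws. A lower `tau` pushes the output closer to a hard $k$-hot vector, typically at the cost of higher-variance gradients.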
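The I-MLE idea for top-$k$ can be sketched as a custom `torch.autograd.Function`. The names `topk_map` and `IMLETopK` are hypothetical and not part of d2sample. The forward pass is perturb-and-MAP: add noise to $\theta$ and take the exact top-$k$ mask $z$. The backward pass forms target parameters $\theta' = \theta - \lambda \, \partial L / \partial z$ and returns $z - \mathrm{MAP}(\theta' + \epsilon)$, reusing the same noise $\epsilon$. For simplicity this sketch assumes Gumbel noise, whereas Niepert et al. propose Sum-of-Gamma noise for $k$-subsets.

```python
import torch


def topk_map(scores: torch.Tensor, k: int) -> torch.Tensor:
    """MAP state of a k-subset distribution: the k-hot mask of the k largest scores."""
    idx = scores.topk(k, dim=-1).indices
    return torch.zeros_like(scores).scatter_(-1, idx, 1.0)


class IMLETopK(torch.autograd.Function):
    """Hypothetical sketch of I-MLE for top-k (Niepert et al., 2021).

    Uses Gumbel noise for simplicity; the paper proposes Sum-of-Gamma noise.
    """

    @staticmethod
    def forward(ctx, theta: torch.Tensor, k: int, lam: float) -> torch.Tensor:
        # Perturb-and-MAP: add Gumbel(0, 1) noise, then take the exact top-k mask.
        u = torch.rand_like(theta).clamp_min(1e-10)
        noise = -torch.log(-torch.log(u))
        z = topk_map(theta + noise, k)
        ctx.save_for_backward(theta, noise, z)
        ctx.k, ctx.lam = k, lam
        return z

    @staticmethod
    def backward(ctx, grad_z: torch.Tensor):
        theta, noise, z = ctx.saved_tensors
        # Target parameters: a step against the downstream gradient.
        theta_target = theta - ctx.lam * grad_z
        # Reuse the same noise sample to get the MAP state of the target.
        z_target = topk_map(theta_target + noise, ctx.k)
        # One gradient per forward input; k and lam are not differentiable.
        return z - z_target, None, None


torch.manual_seed(0)
theta = torch.tensor([[10.0, 9.0, 0.0, 0.0, -10.0]], requires_grad=True)
z = IMLETopK.apply(theta, 2, 100.0)
reward = torch.tensor([0.0, 0.0, 0.0, 0.0, 1.0])
loss = -(z * reward).sum()  # we would like the last element to be selected
loss.backward()
```

Here `lam` controls how far the target moves: if it is too small, $\theta'$ rarely changes the top-$k$ set and most gradients are exactly zero.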
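SIMPLE draws an exact sample from the $k$-subset distribution $p(z \mid \sum_i z_i = k) \propto \exp(\theta^\top z)$ in the forward pass. In the backward pass it uses the Jacobian of the exact marginals $\mu = \mathbb{E}[z]$ in place of $\partial z / \partial \theta$. Below is a minimal single-vector sketch under stated assumptions: all function names are hypothetical, `theta` is 1-D, and $1 \le k \le n$. It computes elementary symmetric polynomials in log-space with an $O(nk)$ dynamic program, uses them for both the marginals (via prefix and suffix tables) and exact backward sampling, and combines the two with a straight-through trick.

```python
import torch

NEG = -1e9  # finite stand-in for log(0); avoids NaN gradients from logaddexp(-inf, -inf)


def log_esp(theta: torch.Tensor, k: int) -> torch.Tensor:
    """Row i holds log e_j(exp(theta[:i])) for j = 0..k (elementary symmetric polynomials)."""
    row = torch.full((k + 1,), NEG, dtype=theta.dtype, device=theta.device)
    row[0] = 0.0  # e_0 = 1 for the empty prefix
    rows = [row]
    for t in theta:
        prev = rows[-1]
        # e_j(w_1..w_i) = e_j(w_1..w_{i-1}) + w_i * e_{j-1}(w_1..w_{i-1}), with w = exp(theta)
        take = torch.cat([prev.new_full((1,), NEG), prev[:-1] + t])
        rows.append(torch.logaddexp(prev, take))
    return torch.stack(rows)  # shape (n + 1, k + 1)


def kset_marginals(theta: torch.Tensor, k: int) -> torch.Tensor:
    """Exact marginals mu_i = P(z_i = 1) under p(z | sum z = k) proportional to exp(theta . z)."""
    n = theta.shape[0]
    pre = log_esp(theta, k)                  # pre[i] covers theta[:i]
    suf = log_esp(theta.flip(0), k).flip(0)  # suf[i] covers theta[i:]
    # log e_{k-1} of all weights except w_i: sum_a e_a(prefix) * e_{k-1-a}(suffix)
    log_excl = torch.logsumexp(pre[:n, :k] + suf[1:, :k].flip(-1), dim=-1)
    return torch.exp(theta + log_excl - pre[n, k])


@torch.no_grad()
def kset_sample(theta: torch.Tensor, k: int) -> torch.Tensor:
    """Exact k-subset sample, drawn backwards through the prefix table."""
    pre = log_esp(theta, k)
    z = torch.zeros_like(theta)
    j = k  # number of items still to select
    for i in range(theta.shape[0], 0, -1):
        if j == 0:
            break
        # P(z_i = 1 | j items among the first i) = w_i e_{j-1}(first i-1) / e_j(first i)
        p = torch.exp(theta[i - 1] + pre[i - 1, j - 1] - pre[i, j])
        if j == i or torch.rand(()) < p:
            z[i - 1] = 1.0
            j -= 1
    return z


def simple_ksubset(theta: torch.Tensor, k: int) -> torch.Tensor:
    """Forward value: exact sample z. Backward: gradient of the exact marginals mu."""
    mu = kset_marginals(theta, k)
    z = kset_sample(theta.detach(), k)
    return z + (mu - mu.detach())  # equals z exactly; gradients flow through mu


torch.manual_seed(0)
theta = torch.randn(6, requires_grad=True)
z = simple_ksubset(theta, k=3)
(z * torch.arange(6.0)).sum().backward()
```

The finite `NEG` sentinel is a design choice of this sketch: using `-inf` for empty polynomials would make `logaddexp` produce NaN gradients on entries that never contribute to the result.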