This is the official PyTorch implementation of the NeurIPS 2022 paper Wasserstein Iterative Networks for Barycenter Estimation (paper on openreview) by Alexander Korotin, Vage Egiazarian, Lingxiao Li and Evgeny Burnaev.
The repository contains (1) the code for the proposed iterative neural algorithm to estimate the Wasserstein-2 barycenter at large scale, and (2) the code to produce the Ave, celeba! dataset, which can be used to benchmark continuous barycenter algorithms.
- Lightning talk by Alexander Korotin at NeurIPS 2022 (December 2022, EN);
- Talk by Evgeny Burnaev at AIRI workshop (15 December 2022, RU);
The implementation is GPU-based. In the experiments, we use from 1 to 4 1080 Ti GPUs. Tested with
`torch==1.9.0`
The code might not run as intended with other torch versions.
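A quick sanity check of the environment before running the notebooks (a minimal sketch, not part of the repository):

```python
# Minimal environment check: prints the installed torch version (the code was
# tested with 1.9.0) and the number of visible GPUs (experiments used 1-4).
import torch

print("torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("visible GPUs:", torch.cuda.device_count())
```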
- Repository for Do Neural Optimal Transport Solvers Work? A Continuous Wasserstein-2 Benchmark paper.
- Repository for Continuous Wasserstein-2 Barycenter Estimation without Minimax Optimization paper.
All the experiments are issued in the form of pretty self-explanatory Jupyter notebooks (`notebooks/`). Auxiliary source code is moved to `.py` modules (`src/`). Input-convex neural networks needed to produce Ave, celeba! are stored as `.pt` checkpoints (`benchmarks/`).
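To peek at one of these checkpoints, something like the following should work (the file name is hypothetical; whether `torch.load` returns a state_dict or a full serialized module depends on how the checkpoint was saved):

```python
# Minimal sketch (hypothetical file name): inspect one of the ICNN checkpoints
# stored under benchmarks/. torch.load may return a dict of tensors (state_dict)
# or a fully serialized module, depending on how it was saved.
import torch

ckpt = torch.load("benchmarks/icnn_potential.pt", map_location="cpu")  # hypothetical name
if isinstance(ckpt, dict):
    print("state_dict keys:", list(ckpt.keys())[:10])
else:
    print("loaded object of type:", type(ckpt).__name__)
```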
Pretrained models for our barycenter algorithm: [Yandex disk (part 1)] [Yandex disk (part 2)]
Precomputed statistics for FID: [Yandex disk]
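With a recent `pytorch-fid`, FID against the precomputed statistics can be computed roughly as follows (a hedged sketch: the `.npz` file name and the samples folder below are hypothetical):

```python
# Rough sketch with pytorch-fid (pip install pytorch-fid). The first path points
# to a precomputed stats .npz, the second to a folder of generated 64x64 images;
# pytorch-fid accepts an .npz of precomputed (mu, sigma) in place of an image folder.
import torch
from pytorch_fid.fid_score import calculate_fid_given_paths

device = "cuda" if torch.cuda.is_available() else "cpu"
fid = calculate_fid_given_paths(
    ["stats/ave_celeba_fid_stats.npz", "samples/barycenter_generated/"],  # hypothetical paths
    batch_size=50,
    device=device,
    dims=2048,
)
print(f"FID: {fid:.2f}")
```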
- `notebooks/WIN_plots.ipynb` - notebook to visualize the results of the trained models;
- `notebooks/WIN_location_scatter.ipynb` - learning barycenters in the location-scatter case (a classical Gaussian fixed-point reference is sketched right after this list);
- `notebooks/WIN_barycenters.ipynb` - learning barycenters of image datasets (Algorithm 1);
- `notebooks/WIN_barycenters_invert.ipynb` - learning inverse maps for image datasets (Algorithm 2);
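For intuition on the location-scatter experiments, the sketch below shows the classical fixed-point iteration for the Wasserstein-2 barycenter of Gaussians. This is a NumPy/SciPy reference computation, not the neural algorithm used in the notebook; all names and the toy data are illustrative.

```python
# Classical fixed-point iteration for the Wasserstein-2 barycenter of Gaussians
# N(mu_k, C_k) with weights w_k (NOT the neural algorithm from the notebooks).
import numpy as np
from scipy.linalg import sqrtm

def gaussian_w2_barycenter(means, covs, weights, iters=50):
    # The barycenter mean is the weighted average of the means.
    mu = sum(w * m for w, m in zip(weights, means))
    # Covariance fixed point: S <- S^{-1/2} (sum_k w_k (S^{1/2} C_k S^{1/2})^{1/2})^2 S^{-1/2}
    S = sum(w * C for w, C in zip(weights, covs))  # initialization
    for _ in range(iters):
        S_half = np.real(sqrtm(S))
        S_half_inv = np.linalg.inv(S_half)
        T = sum(w * np.real(sqrtm(S_half @ C @ S_half)) for w, C in zip(weights, covs))
        S = S_half_inv @ T @ T @ S_half_inv
    return mu, S

# Toy usage: barycenter of two 2D Gaussians with equal weights.
means = [np.zeros(2), np.ones(2)]
covs = [np.diag([1.0, 2.0]), np.diag([2.0, 0.5])]
print(gaussian_w2_barycenter(means, covs, [0.5, 0.5]))
```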
The dataset consists of three non-intersecting degraded subsets (67K+67K+67K images in total) of the 64x64 celebrity faces dataset (202K images). The degradations are constructed by our proposed methodology, which ensures that the barycenter of the 3 full degraded sets (202K+202K+202K images) w.r.t. weights (0.25, 0.5, 0.25) is the original clean celebrity faces dataset.
Ready-to-use Ave, celeba! 64x64 images dataset link: [Yandex disk]
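A minimal loading sketch for the downloaded images (the directory layout and paths below are assumptions, since the archive structure may differ; `ImageFolder` expects images grouped into subfolders):

```python
# Hedged sketch: load one degraded subset of Ave, celeba! with torchvision.
# "data/ave_celeba/subset_0" is a hypothetical path; adjust to the actual layout.
import torch
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
])
subset = datasets.ImageFolder("data/ave_celeba/subset_0", transform=transform)
loader = torch.utils.data.DataLoader(subset, batch_size=64, shuffle=True, num_workers=4)

images, _ = next(iter(loader))
print(images.shape)  # expected: torch.Size([64, 3, 64, 64])
```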
- `datasets/ave_celeba.ipynb` - notebook to produce the Ave, celeba! dataset from scratch;
Warning! It is highly recommended to use the provided dataset rather than recreate it from scratch. Random seeds behave inconsistently across PyTorch versions, which might lead to different produced images.
@inproceedings{
korotin2022wasserstein,
title={Wasserstein Iterative Networks for Barycenter Estimation},
author={Alexander Korotin and Vage Egiazarian and Lingxiao Li and Evgeny Burnaev},
booktitle={Thirty-Sixth Conference on Neural Information Processing Systems},
year={2022},
url={https://openreview.net/forum?id=GiEnzxTnaMN}
}
- Weights & Biases developer tools for machine learning;
- CelebA page with faces dataset and this page with its aligned 64x64 version;
- pytorch-fid repo to compute FID score;
- UNet architecture for transporter network;
- ResNet architectures for generator and discriminator;
- iGAN repository for the datasets of handbags & shoes;
- Fruits 360 Kaggle page for the dataset of fruits.