✨ Information Preservation with Wasserstein Autoencoders: Generation Consistency and Adversarial Robustness
This repository contains code and resources for training Wasserstein Autoencoders (WAE) to explore generation fidelity and adversarial robustness. The project focuses on concurrent density estimation using different latent space distributions and various activation functions.
Before running any experiments, ensure you have all the required Python packages installed. You can install them with:
pip install -r requirements.txt
You can customize the training process with the following arguments:
- --groupsort: Use the GroupSort activation function. Default is 0; set to 1 to enable.
- --js: Use the Jensen-Shannon divergence. Default is 0 (MMD is selected by default); set to 1 to enable.
- --beta: Change the latent space distribution to Beta. Default is 0 (Gaussian latent space).
- --exp: Change the latent space distribution to Exponential. Default is 0 (Gaussian latent space).
- --gauss: Use the Gaussian Ball distribution.
- --mnist: Use the MNIST dataset.
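The README does not show how train.py parses these flags, so the following is only a sketch of one way the documented switches could be wired up with the standard `argparse` module, together with a hypothetical `sample_prior` helper that draws latent codes from the selected prior. Assumptions: `--gauss` and `--mnist` are treated as plain on/off switches (their types are not documented), and the Beta(2, 2) and Exponential(1) parameters are illustrative, not taken from the repository.

```python
import argparse

import numpy as np


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags listed in this README."""
    p = argparse.ArgumentParser(description="WAE training (sketch)")
    # 0/1 integer switches, matching the documented defaults of 0
    p.add_argument("--groupsort", type=int, default=0, choices=[0, 1])
    p.add_argument("--js", type=int, default=0, choices=[0, 1])
    p.add_argument("--beta", type=int, default=0, choices=[0, 1])
    p.add_argument("--exp", type=int, default=0, choices=[0, 1])
    # Dataset selectors; ASSUMED to be simple on/off switches here
    p.add_argument("--gauss", action="store_true")
    p.add_argument("--mnist", action="store_true")
    return p


def sample_prior(args, n: int, dim: int, rng: np.random.Generator) -> np.ndarray:
    """Draw n latent codes from the prior chosen by the flags (Gaussian by default)."""
    if args.beta and args.exp:
        raise ValueError("choose at most one of --beta and --exp")
    if args.beta:
        return rng.beta(2.0, 2.0, size=(n, dim))  # shape parameters are assumptions
    if args.exp:
        return rng.exponential(1.0, size=(n, dim))  # scale is an assumption
    return rng.standard_normal((n, dim))


if __name__ == "__main__":
    args = build_parser().parse_args(["--beta", "1", "--js", "1"])
    z = sample_prior(args, n=4, dim=2, rng=np.random.default_rng(0))
    print(args.js, z.shape)  # 1 (4, 2)
```

Keeping the latent choice in one function makes it easy to compare priors while the rest of the training loop stays unchanged.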
To start training, use the following command:
python3 train.py
To run experiments on the Swiss Roll dataset, navigate to the ./swiss_role directory and execute:
python3 code_v6.py
Ensure that dataset.py is in the same folder, and modify it as needed for your experiments.
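The contents of dataset.py are not shown here, so the following is a hypothetical stand-in that samples a 2-D Swiss Roll (spiral) of the kind commonly used in toy autoencoder experiments. The function name `make_swiss_roll_2d`, the angle range, and the rescaling step are all assumptions; scikit-learn's `sklearn.datasets.make_swiss_roll` provides the standard 3-D variant if that is preferred.

```python
import numpy as np


def make_swiss_roll_2d(n: int, noise: float = 0.0, seed: int = 0) -> np.ndarray:
    """Sample n points from a 2-D Swiss Roll (hypothetical stand-in for dataset.py)."""
    rng = np.random.default_rng(seed)
    # Angle t runs from 1.5*pi to 4.5*pi, i.e. about 1.5 turns of the spiral
    t = 1.5 * np.pi * (1.0 + 2.0 * rng.random(n))
    pts = np.stack([t * np.cos(t), t * np.sin(t)], axis=1)
    if noise > 0:
        # Optional isotropic Gaussian jitter around the spiral
        pts += noise * rng.standard_normal(pts.shape)
    # Rescale so the largest coordinate has magnitude 1 (assumed preprocessing)
    return pts / np.abs(pts).max()


# Usage: 1000 slightly noisy points
data = make_swiss_roll_2d(1000, noise=0.05)
```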
Robustness code for the Gaussian Ball and MNIST datasets is provided in the Robust folder and has been tested with Cauchy, Dirichlet, and Gaussian noise. To run the robustness experiments on the Swiss Roll dataset, navigate to ./Robust/swiss_role/ and run:
python3 robust_code.py
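The three noise families named above can be drawn with NumPy's `Generator` API. This is a minimal sketch, not the code in the Robust folder: the helper name `sample_noise` and the unit location/scale and flat Dirichlet concentration are assumptions.

```python
import numpy as np


def sample_noise(kind: str, n: int, dim: int, rng: np.random.Generator) -> np.ndarray:
    """Return an (n, dim) array of noise samples of the given kind (hypothetical helper)."""
    if kind == "gaussian":
        return rng.normal(0.0, 1.0, size=(n, dim))
    if kind == "cauchy":
        # Heavy-tailed: occasional very large values stress the encoder
        return rng.standard_cauchy(size=(n, dim))
    if kind == "dirichlet":
        # Each row lies on the probability simplex: non-negative and sums to 1
        return rng.dirichlet(np.ones(dim), size=n)
    raise ValueError(f"unknown noise kind: {kind}")
```

Cauchy noise is the most demanding of the three because it has no finite mean or variance, while Dirichlet samples are bounded and concentrated on the simplex.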
To change the position of the Gaussian Balls or increase the number of clusters, modify the dataset.py file.
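As an illustration of what such a change involves, the following sketch samples an equal-weight mixture of isotropic Gaussian "balls" whose centres are placed on a circle. The helper names `make_gaussian_balls` and `ring_centers`, the ring layout, and the default `std` are hypothetical and may not match dataset.py.

```python
import numpy as np


def make_gaussian_balls(n: int, centers, std: float = 0.1, seed: int = 0):
    """Sample n points from an equal-weight mixture of isotropic Gaussians.

    Hypothetical helper: returns (points, labels), where labels[i] is the
    index of the ball that point i was drawn from.
    """
    rng = np.random.default_rng(seed)
    centers = np.asarray(centers, dtype=float)        # shape (k, dim)
    labels = rng.integers(0, len(centers), size=n)    # choose a ball per point
    points = centers[labels] + std * rng.standard_normal((n, centers.shape[1]))
    return points, labels


def ring_centers(k: int, radius: float = 1.0) -> np.ndarray:
    """Place k ball centres evenly on a circle; raising k adds clusters."""
    angles = 2 * np.pi * np.arange(k) / k
    return radius * np.stack([np.cos(angles), np.sin(angles)], axis=1)


# Usage: 8 clusters on a ring of radius 2
x, y = make_gaussian_balls(1000, ring_centers(8, radius=2.0), std=0.05)
```

Moving a ball means editing its row in `centers`; adding clusters means passing more rows.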
We recommend a weight of 0.2 for the MMD or JS divergence term when it is combined with the reconstruction loss. You can adjust this ratio in config.py.
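To make the weighting concrete, here is a NumPy sketch of a WAE-style objective: squared-error reconstruction plus a weight `lam` times an MMD penalty between encoded latents and prior samples, with `lam=0.2` following the recommendation above. Assumptions: an RBF kernel with bandwidth `sigma=1.0` (the original WAE paper favours an inverse multiquadratic kernel, so this is a swapped-in choice), and the hypothetical names `rbf_mmd2` and `wae_loss`. The JS option is not sketched here.

```python
import numpy as np


def rbf_mmd2(x: np.ndarray, y: np.ndarray, sigma: float = 1.0) -> float:
    """Unbiased estimate of squared MMD between samples x (n, d) and y (m, d), RBF kernel."""
    def k(a, b):
        d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
        return np.exp(-d2 / (2 * sigma ** 2))

    n, m = len(x), len(y)
    kxx, kyy, kxy = k(x, x), k(y, y), k(x, y)
    # Exclude diagonal (self-similarity) terms for the unbiased within-sample means
    term_xx = (kxx.sum() - np.trace(kxx)) / (n * (n - 1))
    term_yy = (kyy.sum() - np.trace(kyy)) / (m * (m - 1))
    return float(term_xx + term_yy - 2 * kxy.mean())


def wae_loss(x, x_rec, z, z_prior, lam: float = 0.2, sigma: float = 1.0) -> float:
    """Mean squared-L2 reconstruction error plus lam times the latent MMD penalty."""
    rec = float(((x - x_rec) ** 2).sum(axis=1).mean())
    return rec + lam * rbf_mmd2(z, z_prior, sigma)
```

A larger `lam` pushes the aggregate posterior harder toward the prior at the cost of reconstruction quality; a smaller one does the opposite.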
We encourage open collaboration on hyperparameter settings. All configurations are available in config.py for easy modification and experimentation.
The model.py file contains a simple dense neural network for MNIST and Gaussian Ball reconstruction.
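Since the --groupsort flag swaps the activation of this dense network, the following NumPy sketch shows what a GroupSort activation does and how it could slot into a plain MLP forward pass. It is illustrative only: the helper names `groupsort` and `dense_forward`, the group size of 2 (the MaxMin special case), and the ReLU fallback are assumptions, not the contents of model.py.

```python
import numpy as np


def groupsort(x: np.ndarray, group_size: int = 2) -> np.ndarray:
    """GroupSort: split the last axis into groups and sort each group ascending.

    With group_size=2 this is the MaxMin activation. It only permutes values,
    so it preserves the vector norm and is 1-Lipschitz.
    """
    *lead, d = x.shape
    if d % group_size:
        raise ValueError("feature dimension must be divisible by group_size")
    grouped = x.reshape(*lead, d // group_size, group_size)
    return np.sort(grouped, axis=-1).reshape(*lead, d)


def dense_forward(x, weights, biases, use_groupsort: bool = False):
    """Plain MLP forward pass; hidden layers use GroupSort or ReLU, the last layer is linear."""
    h = x
    for i, (w, b) in enumerate(zip(weights, biases)):
        h = h @ w + b
        if i < len(weights) - 1:
            h = groupsort(h) if use_groupsort else np.maximum(h, 0.0)
    return h
```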
In robustness testing, the datasets are mixed with noise at specific ratios:
- Gaussian Ball: see lines 127 and 129 in cauchy.py.
- Dirichlet: see lines 134 and 137 in dirichlet.py.

Adjust the ratios on lines 95 and 101, respectively. Similar modifications can be made for the MNIST dataset.
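The idea of a mixing ratio can be sketched as replacing a fixed fraction of clean samples with noise samples. The helper `mix_with_noise` below is hypothetical, and the scripts in the Robust folder may combine data and noise differently (for example by appending rather than replacing rows).

```python
import numpy as np


def mix_with_noise(clean: np.ndarray, noise: np.ndarray, ratio: float, seed: int = 0) -> np.ndarray:
    """Replace a fraction `ratio` of the rows of `clean` with rows drawn from `noise`.

    Hypothetical helper illustrating a contamination ratio; `clean` is not modified.
    """
    if not 0.0 <= ratio <= 1.0:
        raise ValueError("ratio must be in [0, 1]")
    rng = np.random.default_rng(seed)
    n = len(clean)
    k = int(round(ratio * n))                    # number of rows to corrupt
    out = clean.copy()
    idx = rng.choice(n, size=k, replace=False)   # which clean rows to replace
    # Sample noise rows, with replacement only if there are too few of them
    src = rng.choice(len(noise), size=k, replace=len(noise) < k)
    out[idx] = noise[src]
    return out
```

For example, `mix_with_noise(data, noise, 0.2)` corrupts 20% of the rows; sweeping `ratio` gives a simple robustness curve.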