VAE Implementation in PyTorch with Visualizations

This repository implements a simple VAE that trains on CPU on the MNIST dataset. It can visualize the latent space and the entire manifold, and show how digits interpolate into one another.

The purpose of this project is to build a better understanding of VAEs by experimenting with the different parameters and visualizations.

VAE Tutorial Videos

  • VAE Understanding
  • Implementing VAE

Architecture

Quickstart

  • Create a new conda environment with Python 3.8, then run the commands below
  • git clone https://github.com/explainingai-code/Pytorch-VAE.git
  • cd Pytorch-VAE
  • pip install -r requirements.txt
  • To run a simple VAE built from fully connected layers with a latent dimension of 2, run python run_simple_vae.py
  • To experiment with the VAE and run the visualizations, either change the config argument in tools/train_vae.py and tools/inference.py to the desired config file, or pass the desired config to the commands below
  • python -m tools.train_vae
  • python -m tools.inference

Configurations

  • config/vae_nokl.yaml - VAE with only reconstruction loss
  • config/vae_kl.yaml - VAE with reconstruction and KL loss
  • config/vae_kl_latent4.yaml - VAE with reconstruction and KL loss and a latent dimension of 4 (instead of 2)
  • config/vae_kl_latent4_enc_channel_dec_fc_condition.yaml - Conditional VAE with reconstruction and KL loss and a latent dimension of 4
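
The first two configurations differ only in whether the KL term is part of the objective. As a rough sketch, assume the encoder outputs a mean and a log-variance per latent dimension, the prior is a standard normal, and the reconstruction term is a squared error. The squared error and the unweighted sum are assumptions for illustration; the repository may use a different reconstruction loss or weighting. Under those assumptions the two objectives can be written as:

```latex
\begin{align}
\mathcal{L}_{\text{nokl}} &= \lVert x - \hat{x} \rVert^2 \\
\mathcal{L}_{\text{kl}}   &= \lVert x - \hat{x} \rVert^2
    + D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \sigma^2 I) \,\|\, \mathcal{N}(0, I)\big) \\
D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \sigma^2 I) \,\|\, \mathcal{N}(0, I)\big)
    &= -\frac{1}{2} \sum_{j=1}^{d} \left( 1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2 \right)
\end{align}
```

Here d is the latent dimension (2 or 4 in the configurations above). Without the KL term nothing pulls the encoded distribution toward the prior, which is the difference the vae_nokl.yaml and vae_kl.yaml runs are meant to expose.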

Data preparation

This repository does not use the torchvision MNIST dataset, so that MNIST can be replaced with any other image dataset.

For setting up the dataset:

Verify the data directory has the following structure:

Pytorch-VAE/data/train/images/{0/1/.../9}
	*.png
Pytorch-VAE/data/test/images/{0/1/.../9}
	*.png
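
The layout above can be checked before training with a small script. The following is a minimal sketch, not part of the repository: the function name count_images is hypothetical, and it assumes only the directory structure shown above (one folder per digit, PNG files inside).

```python
from pathlib import Path


def count_images(root, splits=("train", "test"), classes=range(10)):
    """Count *.png files under <root>/<split>/images/<class>/.

    Returns a dict keyed by (split, class).
    Raises FileNotFoundError if an expected class directory is missing.
    """
    counts = {}
    for split in splits:
        for cls in classes:
            class_dir = Path(root) / split / "images" / str(cls)
            if not class_dir.is_dir():
                raise FileNotFoundError(f"missing directory: {class_dir}")
            # Only files directly inside the class folder are counted.
            counts[(split, cls)] = sum(1 for _ in class_dir.glob("*.png"))
    return counts
```

Run from the repository root, count_images("data") would either report how many images each digit folder holds or fail on the first missing folder.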

Output

Outputs are saved according to the configuration in the YAML files.

For every run, a folder named after the task_name key in the config is created, and an output_train_dir folder is created inside it.
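
The resulting folder structure can be sketched as follows. This is an illustration only: prepare_output_dirs is a hypothetical helper, and it assumes a flat dict holding the task_name and output_train_dir keys, whereas the actual YAML files may nest these keys differently.

```python
import os


def prepare_output_dirs(config):
    """Create <task_name>/<output_train_dir> and return both paths.

    Assumes `config` is a flat dict; the real config layout may differ.
    """
    task_dir = config["task_name"]
    train_dir = os.path.join(task_dir, config["output_train_dir"])
    # makedirs creates the task folder as well if it does not exist yet.
    os.makedirs(train_dir, exist_ok=True)
    return task_dir, train_dir
```

Calling the helper twice with the same config is harmless, since existing folders are reused rather than recreated.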

During training, the following outputs are saved:

  • Best model checkpoints in the task_name directory
  • PCA information in a pickle file in the task_name directory
  • A 2D latent space plot of the test set images for each epoch in the task_name/output_train_dir directory

During inference, the following outputs are saved:

  • Reconstructions for a sample of the test set in task_name/output_train_dir/reconstruction.png
  • Decoder output for a sample of points evenly spaced across the 2D projection of the latent space in task_name/output_train_dir/manifold.png
  • Interpolation between two randomly sampled points in the task_name/output_train_dir/interp directory
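
The manifold and interpolation outputs both come down to choosing latent points and decoding them. The sketch below shows only the point selection, in plain Python so it runs without any dependencies. The helper names are hypothetical, and linear interpolation is an assumption about how the intermediate points are chosen. For latent dimensions above 2, mapping grid points from the 2D projection back to the latent space is not covered here.

```python
def linspace(start, stop, num):
    """Return `num` evenly spaced values from start to stop, inclusive."""
    if num == 1:
        return [float(start)]
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def latent_grid(low, high, num):
    """Return a num x num grid of 2D points evenly spaced over [low, high]."""
    axis = linspace(low, high, num)
    # Row-major order: x varies fastest, y changes once per row.
    return [(x, y) for y in axis for x in axis]


def interpolate(z_start, z_end, steps):
    """Linearly interpolate between two latent vectors, endpoints included."""
    weights = linspace(0.0, 1.0, steps)
    return [
        [(1 - t) * a + t * b for a, b in zip(z_start, z_end)]
        for t in weights
    ]
```

Each point from latent_grid or interpolate would then be passed through the trained decoder to produce one tile of the manifold image or one frame of the interpolation.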

Sample Output for VAE

Latent Visualization

Manifold

Reconstruction Images (reconstruction in black font and original in white font)

Sample Output for Conditional VAE

Because the label is passed to the decoder, the model learns to generate every digit from every point in the latent space.
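
One common way to pass a label to a decoder is to append a one-hot encoding of it to the latent vector. The sketch below illustrates that idea with plain Python lists. It is an assumption made for illustration, not necessarily how this repository conditions its decoder, and the helper names are hypothetical.

```python
def one_hot(label, num_classes=10):
    """Return a one-hot list encoding `label`."""
    if not 0 <= label < num_classes:
        raise ValueError(f"label {label} outside [0, {num_classes})")
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


def decoder_input(z, label, num_classes=10):
    """Concatenate a latent vector with the one-hot label (illustrative only)."""
    return list(z) + one_hot(label, num_classes)


# The same latent point paired with every digit label: the latent part is
# identical across inputs, so only the label decides which digit is requested.
z = [0.3, -1.2, 0.5, 0.0]
inputs = [decoder_input(z, digit) for digit in range(10)]
```

With tensors, the same concatenation would typically be done with torch.cat along the feature dimension.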

The model learns to distinguish points in the latent space by style, such as whether to generate a left- or right-tilted digit, or how thick the stroke should be. These patterns can be visualized below, where all digits are generated from all points.

Reconstruction Images (reconstruction in black font and original in white font)