Variational Autoencoders were first described in the paper:
- "Auto-encoding variational Bayes" by Kingma and Welling, (here)
Some great tutorials on the Variational Autoencoder can be found in the papers:
- "Tutorial on variational autoencoders" by Carl Doersch, (here)
- "An introduction to variational autoencoders" by Kingma and Welling, (here)
A very simple and useful implementation of an autoencoder and a variational autoencoder can be found in this blog post. The autoencoders are trained on MNIST, and some cool visualizations of the latent space are shown.
The equation that is at the core of the variational autoencoder is:

$$\log p_\theta(x) - D_{KL}\big(q_\phi(z|x)\,\|\,p_\theta(z|x)\big) = \mathbb{E}_{z \sim q_\phi(z|x)}\big[\log p_\theta(x|z)\big] - D_{KL}\big(q_\phi(z|x)\,\|\,p(z)\big)$$

The left-hand side has the quantity that we want to optimize: the log-likelihood $\log p_\theta(x)$ of the data, minus an error term measuring how far the approximate posterior $q_\phi(z|x)$ produced by the encoder is from the true (intractable) posterior $p_\theta(z|x)$. The right-hand side is a quantity we can actually compute and maximize, and the rest of this post builds it up step by step.
Autoencoders are trained to encode input data into a smaller feature vector,
and afterwards reconstruct it into the original input. In general, an autoencoder
consists of an encoder that maps the input $x$ into a low-dimensional latent vector $z = e(x)$, and a decoder that maps the latent vector back into the input space, $\hat{x} = d(z)$. The model is trained by minimizing the reconstruction error between the input $x$ and its reconstruction $\hat{x}$.
Suppose that both the encoder and the decoder architectures have only one hidden layer without any non-linearity (a linear autoencoder). In this case we can see a clear connection with PCA, in the sense that we are looking for the best linear subspace onto which to project the data. In general, both the encoder and the decoder are deep non-linear networks, and thus inputs are encoded onto a much more complex, non-linear manifold.
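As a minimal sketch of this encoder-decoder structure (the layer sizes here are illustrative, not the architecture used later in this post):

```python
import torch
import torch.nn as nn

# Minimal autoencoder sketch: a deterministic encoder and decoder trained
# to minimize the reconstruction error. Sizes are illustrative (e.g. flattened MNIST).
encoder = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 32))
decoder = nn.Sequential(nn.Linear(32, 256), nn.ReLU(), nn.Linear(256, 784))

def reconstruction_loss(x):
    z = encoder(x)                      # compress into the latent space
    x_hat = decoder(z)                  # decompress back into the input space
    return ((x - x_hat) ** 2).mean()    # mean squared reconstruction error
```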
Once we have a trained autoencoder we can use the encoder to "compress" inputs
from our high-dimensional input space into the low-dimensional latent space. And
we can also use the decoder to decompress them back into the high-dimensional
input space. But there is no convenient way to generate any new data points from
our input space. In order to generate a new data point we need to sample a
feature vector from the latent space and decode it afterwards. However, there is
no convenient way to choose "good" samples from our latent space. To solve this
problem, variational autoencoders introduce a new idea for compression. Instead
of encoding the input as a single feature vector, the variational autoencoder
will encode the input as a probability distribution over the latent space. In practice, the encoded distribution is chosen to be a multivariate normal distribution with a diagonal covariance matrix. Thus, the output of the encoder component will be a vector of means $\mu$ and a vector of variances $\sigma^2$ (usually predicted as log-variances $\log \sigma^2$ for numerical stability), one pair per latent dimension.
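A minimal sketch of such an encoder head (the backbone and sizes are illustrative):

```python
import torch.nn as nn

class VariationalEncoder(nn.Module):
    """Encodes an input into the mean and log-variance of a diagonal Gaussian."""
    def __init__(self, in_dim=784, hidden=256, latent_dim=32):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU())
        self.mu_head = nn.Linear(hidden, latent_dim)        # mean of q(z|x)
        self.log_var_head = nn.Linear(hidden, latent_dim)   # log-variance of q(z|x)

    def forward(self, x):
        h = self.backbone(x)
        return self.mu_head(h), self.log_var_head(h)
```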
The mathematics behind Variational Autoencoders actually has very little to do with classical autoencoders. They are called "autoencoders" only because the architecture does have an encoder and a decoder and resembles a traditional autoencoder.
First of all, what we want to achieve is to produce a model that can generate data points from the space of our training data. To do this we will assume that there is some latent space $\mathcal{Z}$ with a prior distribution $p(z)$ over it, and a decoder model $p_\theta(x|z)$ that produces a data point $x$ given a latent vector $z$. The marginal probability of a data point $x$ is then

$$p_\theta(x) = \int p_\theta(x|z)\, p(z)\, dz,$$

and our objective is to maximize the likelihood of the training data:

$$\max_\theta \sum_i \log p_\theta\big(x^{(i)}\big)$$
Using this approach we expect the model to learn to decode nearby latent vectors into similar data points.
If we choose the latent space to be $\mathbb{R}^d$ with a standard normal prior $p(z) = \mathcal{N}(0, I)$, then generating a new data point is easy: sample $z \sim p(z)$ and decode it with $p_\theta(x|z)$. However, one problem arises in this setting. In order to optimize our model we actually want to sample latent vectors that are likely to have produced a given training point $x$, but for most $z$ drawn from the prior $p_\theta(x|z)$ is vanishingly small, so a naive Monte Carlo estimate of $p_\theta(x)$ would need an enormous number of samples.
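To make this concrete, here is a hypothetical sketch of such a naive estimate; `decoder_log_prob` is a stand-in for evaluating $\log p_\theta(x|z)$, not a function from the actual code:

```python
import math
import torch

def naive_log_px(x, decoder_log_prob, latent_dim, num_samples=10_000):
    """Naive Monte Carlo estimate of log p(x) using samples from the prior."""
    z = torch.randn(num_samples, latent_dim)       # z ~ p(z) = N(0, I)
    log_px_given_z = decoder_log_prob(x, z)        # log p(x|z), shape (num_samples,)
    # log p(x) ~= log( (1/N) * sum_i p(x|z_i) ), computed stably in log-space
    return torch.logsumexp(log_px_given_z, dim=0) - math.log(num_samples)
```

For most prior samples $\log p_\theta(x|z)$ is extremely small, so the estimate is dominated by the rare samples that happen to explain $x$ well.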
What we want to do is actually sample from a different distribution, one that is much more likely to yield a useful value of $z$ for the given data point $x$. This is the role of the encoder: it models a distribution $q_\phi(z|x)$ that approximates the true posterior $p_\theta(z|x)$. Now, using both models, the marginal probability of a given data point $x$ can be decomposed as:

$$\log p_\theta(x) = \mathbb{E}_{z \sim q_\phi(z|x)}\big[\log p_\theta(x|z)\big] - D_{KL}\big(q_\phi(z|x)\,\|\,p(z)\big) + D_{KL}\big(q_\phi(z|x)\,\|\,p_\theta(z|x)\big)$$
And to optimize the model we simply have to maximize this quantity by optimizing the parameters $\theta$ of the decoder. However, in addition to maximizing the probability of our data, we would also want the distribution $q_\phi(z|x)$ produced by the encoder to be as close as possible to the true posterior $p_\theta(z|x)$, i.e., we also want to minimize the last term $D_{KL}\big(q_\phi(z|x)\,\|\,p_\theta(z|x)\big)$. This term cannot be computed directly, but since a KL divergence is always non-negative, dropping it leaves the first two terms as a lower bound on $\log p_\theta(x)$.
From here we can see that we are left with the same objective, only this time we optimize over the parameters $\phi$ of the encoder as well as the parameters $\theta$ of the decoder. Finally, our objective is:

$$\max_{\theta,\phi}\; \mathbb{E}_{z \sim q_\phi(z|x)}\big[\log p_\theta(x|z)\big] - D_{KL}\big(q_\phi(z|x)\,\|\,p(z)\big)$$
This objective is known as the Variational Lower Bound (VLB) and in this form
there is a very intuitive interpretation: The first term is the reconstruction
loss and encourages the decoder to learn to reconstruct the data. The second
term is a regularizer that tries to push the distribution produced by the encoder
towards the prior $p(z)$.
In case $p_\theta(x|z)$ is modeled as a Gaussian centered at the decoder output with fixed variance, maximizing the first term is equivalent to minimizing the squared reconstruction error $\|x - d_\theta(z)\|^2$, where $d_\theta(z)$ denotes the output of the decoder for the latent vector $z$.
Also note that there is a closed formula for the KL divergence between two multivariate Gaussian distributions. In our case $q_\phi(z|x) = \mathcal{N}\big(\mu, \operatorname{diag}(\sigma^2)\big)$ and $p(z) = \mathcal{N}(0, I)$, so the second term becomes

$$D_{KL}\big(q_\phi(z|x)\,\|\,p(z)\big) = \frac{1}{2} \sum_{j=1}^{d} \big(\mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1\big),$$

where $\mu_j$ and $\sigma_j^2$ are the components of the mean and variance vectors produced by the encoder, and $d$ is the dimension of the latent space.
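Putting the two terms together, a minimal sketch of the resulting training loss (the negative VLB) could look like this, assuming the encoder outputs `mu` and `log_var` and the decoder output `x_hat` is compared to `x` with a squared error:

```python
import torch
import torch.nn.functional as F

def vae_loss(x, x_hat, mu, log_var):
    """Negative VLB: reconstruction error plus closed-form KL to N(0, I)."""
    recon = F.mse_loss(x_hat, x, reduction="sum")                     # reconstruction term
    kl = 0.5 * torch.sum(mu ** 2 + log_var.exp() - log_var - 1.0)     # KL regularizer
    return recon + kl
```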
One problem with our current objective is that we need to sample from $q_\phi(z|x)$ in order to estimate the expectation, and sampling is not a differentiable operation, so gradients cannot flow back to the parameters $\phi$ of the encoder. This is actually a very well studied problem in reinforcement learning, where we want to compute the derivative of an expectation with respect to the parameters of the distribution we are sampling from. In general, the derivative of $\mathbb{E}_{z \sim p_\psi(z)}\big[f(z)\big]$ with respect to $\psi$ is given by:

$$\nabla_\psi\, \mathbb{E}_{z \sim p_\psi(z)}\big[f(z)\big] = \mathbb{E}_{z \sim p_\psi(z)}\big[f(z)\, \nabla_\psi \log p_\psi(z)\big]$$

This is the score-function (REINFORCE) estimator. The problem with this formula is that we actually need to take a lot of samples in order to reduce the variance of the estimate, otherwise the computed gradient will be very noisy and might not result in any meaningful updates.
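As a toy illustration (not part of the VAE code), here is how this estimator could be used to approximate $\frac{d}{d\mu}\,\mathbb{E}_{z \sim \mathcal{N}(\mu, 1)}\big[z^2\big]$, whose true value is $2\mu$:

```python
import torch

# Score-function estimate of d/dmu E_{z ~ N(mu, 1)}[z^2]  (true value: 2 * mu).
mu = torch.tensor(1.0, requires_grad=True)
dist = torch.distributions.Normal(mu, 1.0)
z = dist.sample((1000,))                          # samples carry no gradient
surrogate = (z ** 2 * dist.log_prob(z)).mean()    # gradient of this equals the estimator
surrogate.backward()
print(mu.grad)                                    # noisy estimate of 2.0
```

Running this repeatedly shows the estimate fluctuating around the true value of 2, even with 1000 samples per run.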
Another approach to solve this problem is the so-called "reparametrization trick". Note that we are parametrizing $q_\phi(z|x)$ as a Gaussian with mean $\mu$ and diagonal covariance $\operatorname{diag}(\sigma^2)$, so instead of sampling $z$ directly we can sample $\varepsilon \sim \mathcal{N}(0, I)$ and set $z = \mu + \sigma \odot \varepsilon$. And, thus, our objective becomes:

$$\max_{\theta,\phi}\; \mathbb{E}_{\varepsilon \sim \mathcal{N}(0, I)}\big[\log p_\theta(x \mid \mu + \sigma \odot \varepsilon)\big] - D_{KL}\big(q_\phi(z|x)\,\|\,p(z)\big)$$

Computing the derivative can now be done by simply sampling $\varepsilon$ from a standard normal distribution: the randomness no longer depends on $\phi$, so gradients can flow through $\mu$ and $\sigma$ back to the encoder. In practice a single sample per data point is typically enough.
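In code, the trick amounts to a few lines (assuming `mu` and `log_var` come from the encoder, as in the sketches above):

```python
import torch

def reparametrize(mu, log_var):
    """Sample z ~ N(mu, diag(sigma^2)) in a differentiable way."""
    std = torch.exp(0.5 * log_var)    # sigma = exp(log(sigma^2) / 2)
    eps = torch.randn_like(std)       # eps ~ N(0, I), independent of the parameters
    return mu + std * eps             # gradients flow through mu and std
```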
We will train the model on the CIFAR-10 dataset. For the encoder-decoder structure we will use a U-Net with the contraction path corresponding to the encoder and the expansion path corresponding to the decoder.
For the contraction path we will employ a standard ResNet with three separate groups of blocks, each time halving the spatial dimensions and doubling the number of channels. In every group there are four residual blocks: the first block downscales the input, and the other three are standard residual blocks operating on the same scale, i.e., the number of channels and the spatial dimensions remain fixed. The expansion path will use the same ResNet, but instead the first block of each group will upscale the image, doubling the spatial size and halving the number of channels.
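As a rough sketch of what a downscaling block in one of these groups could look like (the exact block definitions are in the repository code; the layers below are illustrative):

```python
import torch.nn as nn
import torch.nn.functional as F

class DownBlock(nn.Module):
    """Residual block that halves the spatial size and doubles the channels."""
    def __init__(self, in_ch):
        super().__init__()
        out_ch = 2 * in_ch
        self.body = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
        )
        self.skip = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=2)  # match the residual shape

    def forward(self, x):
        return F.relu(self.body(x) + self.skip(x))
```

The upscaling block of the expansion path would mirror this, e.g. with a transposed convolution in place of the strided one.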
For more on the ResNet check out a blog post I wrote on the topic.
To train the model simply run:
```
python3 run.py --seed 0 --epochs 25
```
The script will download the CIFAR-10 dataset into a `datasets` folder and will train the model on it. The trained model parameters will be saved to the file `vae.pt`.
During model training we track how well it reconstructs images from their latent representations. The images used for reconstruction are drawn from the validation set. A visualization of the reconstructed images for all training epochs is shown below. We can see that after a few epochs the model learns well how to reconstruct the images.
Finally, to use the trained model to generate images run the following:
```python
import torch
import torchvision
import matplotlib.pyplot as plt

vae = torch.load("vae.pt")        # load the trained model
imgs = vae.sample(n=49)           # imgs.shape = (49, 3, 32, 32)
grid = torchvision.utils.make_grid(imgs, nrow=7)
plt.imshow(grid.permute(1, 2, 0))
```
This is what the model generates after training for 50 epochs.