Implementation of Generating Diverse High-Fidelity Images with VQ-VAE-2 in PyTorch
- Python >= 3.6
- PyTorch >= 1.1
- lmdb (for storing extracted codes)
Checkpoint of VQ-VAE pretrained on FFHQ
Currently supports 256px (top/bottom hierarchical prior)
- Stage 1 (VQ-VAE)
python train_vqvae.py [DATASET PATH]
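The heart of stage 1 is the vector-quantization bottleneck: encoder outputs are snapped to their nearest codebook vectors, and the resulting integer indices are what stage 2 later models. Below is a minimal sketch of that step, not the repo's exact `Quantize` module (which also maintains the codebook with an EMA update):

```python
import torch

def quantize(z, codebook):
    """Map encoder outputs to their nearest codebook entries.

    z:        (batch, dim) encoder output vectors
    codebook: (n_embed, dim) learned embedding vectors
    Returns the quantized vectors (with a straight-through gradient)
    and the integer code indices that stage 2 is trained on.
    """
    dist = torch.cdist(z, codebook)      # (batch, n_embed) pairwise distances
    indices = dist.argmin(dim=1)         # nearest codebook entry per vector
    quantized = codebook[indices]
    # Straight-through estimator: gradients flow to z, bypassing the argmin.
    quantized = z + (quantized - z).detach()
    return quantized, indices
```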
If you use FFHQ, I highly recommend preprocessing the images first (resize and convert to JPEG).
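The preprocessing could look like the following sketch (function name and JPEG quality are my choices, not part of the repo):

```python
import os
from PIL import Image

def preprocess(src_dir, dst_dir, size=256):
    """Resize every image in src_dir to size x size and re-encode as JPEG."""
    os.makedirs(dst_dir, exist_ok=True)
    for name in os.listdir(src_dir):
        img = Image.open(os.path.join(src_dir, name)).convert("RGB")
        img = img.resize((size, size), Image.LANCZOS)
        base, _ = os.path.splitext(name)
        img.save(os.path.join(dst_dir, base + ".jpg"), quality=95)
```

Doing this once up front avoids repeatedly decoding large PNGs in the data loader.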
- Extract codes for stage 2 training
python extract_code.py --ckpt checkpoint/[VQ-VAE CHECKPOINT] --name [LMDB NAME] [DATASET PATH]
- Stage 2 (PixelSNAIL)
python train_pixelsnail.py [LMDB NAME]
A larger PixelSNAIL model would likely give better results; the model size here is reduced due to GPU memory constraints.