This is a straightforward implementation of ProGAN, a pioneering approach proposed by Nvidia. It is specifically designed to enable the training of models that progressively generate images, starting from low resolution and gradually increasing to higher resolutions. With this implementation, you have the flexibility to utilize your own dataset for training purposes.
The underlying theory and technical details of ProGAN are described in the paper "Progressive Growing of GANs for Improved Quality, Stability, and Variation" (Karras et al., NVIDIA). It explains how progressively growing Generative Adversarial Networks (GANs) improves the quality, stability, and variation of the generated images, and is a valuable resource for understanding the details of ProGAN.
The directory structure should be the following:
/ProGAN
|--/data
|--/images
|--/save
|--/workdir
|--dataset.py
|--network.py
|--setting.py
|--test.py
|--train.py
- `/data`: This directory should contain your dataset. Place images directly into it without any additional subfolders.
- `/images`: Sample images are written here during training, allowing you to monitor the training process.
- `/save`: Models are saved here during training, once after every epoch.
- `/workdir`: This directory holds the current model.
- `dataset.py`: A tool for loading data from the `/data` directory.
- `network.py`: Defines the structure of the network.
- `setting.py`: Contains the parameters for the network and training.
- `test.py`: A tool for testing the dataloader and model.
- `train.py`: Trains the model.
To use this repository, you just need to clone it:
git clone https://github.com/Keyan0412/ProGAN.git
First, ensure all your images are placed in the `/data` folder without any additional subfolders. Then configure the necessary parameters in `setting.py`. Once done, you can start training by executing `train.py`. To see how to run it, use the command:
python train.py -h
This will display the following usage information:
usage: train.py [-h] -s STEP [-e EPOCH] [-f] [-n]
Train your model on your own dataset
options:
-h, --help show this help message and exit
-s STEP, --step STEP step of training (default: None)
-e EPOCH, --epoch EPOCH
number of epoch (default: 20)
-f, --fade using fade in (default: False)
-n, --new train new model (default: False)
In the usage above, `step` refers to the stage of training: the image size at a given step is `4 * 2^step` (for example, step 3 trains on 32×32 images). If you wish to train a new model instead of loading an existing one, include the `-n` flag. When starting training at a new step, use the `-f` flag to enable fading in the newly added layers.
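The fade-in used when entering a new step can be sketched as a linear blend between the upsampled output of the previous resolution and the output of the newly added block. This is a minimal NumPy sketch of the idea; the function and variable names are illustrative and not taken from `network.py`:

```python
import numpy as np

def fade_in(alpha, upscaled, generated):
    """Linearly blend the upsampled low-resolution output with the new
    block's output; alpha ramps from 0 to 1 during the fade-in phase."""
    return alpha * generated + (1 - alpha) * upscaled

# Image size at a given step is 4 * 2**step.
sizes = [4 * 2**step for step in range(5)]  # [4, 8, 16, 32, 64]

up = np.zeros((1, 3, 8, 8))    # upsampled previous-resolution output
new = np.ones((1, 3, 8, 8))    # output of the newly added 8x8 block
mixed = fade_in(0.3, up, new)  # 30% new block, 70% upsampled output
```

At `alpha = 0` the network still produces the old low-resolution output, and at `alpha = 1` the new block has fully taken over, which is what makes growing to a new step stable.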
To understand how to execute testing, run the command:
python test.py -h
This will display the following usage information:
usage: test.py [-h] -s STEP
Test your model and dataset
options:
-h, --help show this help message and exit
-s STEP, --step STEP step of training (default: None)
Running the test generates two images named `test_*.png` in the `/images` folder: one produced by the current model and one obtained from a data sample. Before running it, ensure `model.pth.tar` exists in `/workdir` and that the parameters in `setting.py` are set correctly.
When training, a warning might appear:
/data/miniconda/envs/torch/lib/python3.10/site-packages/torch/autograd/graph.py:768: UserWarning: Attempting to run cuBLAS, but there was no current CUDA context! Attempting to set the primary context... (Triggered internally at /opt/conda/conda-bld/pytorch_1720165264854/work/aten/src/ATen/cuda/CublasHandlePool.cpp:135.)
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
In testing, this warning was found not to affect the training speed.
Furthermore, when setting the batch size to 1, another warning may arise due to the mini-batch standard deviation block:
/data/coding/network.py:152: UserWarning: std(): degrees of freedom is <= 0. Correction should be strictly less than the reduction factor (input numel divided by output numel). (Triggered internally at /opt/conda/conda-bld/pytorch_1720165264854/work/aten/src/ATen/native/ReduceOps.cpp:1808.)
torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
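This warning arises because the standard deviation is computed across the batch dimension, and with a batch of one example the default Bessel correction (the `torch.std` default) leaves N - 1 = 0 degrees of freedom. The mini-batch standard-deviation feature can be sketched in NumPy as follows; this is an illustration of the technique, not the exact code from `network.py`:

```python
import numpy as np

def minibatch_std(x):
    """Append one feature map holding the mean, over all positions, of the
    per-position standard deviation across the batch (ddof=1 mimics the
    Bessel correction that torch.std applies by default)."""
    std = x.std(axis=0, ddof=1)  # std over the batch dimension
    stat = np.full((x.shape[0], 1, x.shape[2], x.shape[3]), std.mean())
    return np.concatenate([x, stat], axis=1)

x = np.random.randn(4, 8, 16, 16)
out = minibatch_std(x)  # one extra channel: shape (4, 9, 16, 16)

# With batch size 1, the correction leaves zero degrees of freedom,
# so the statistic is NaN -- hence the warning.
single = np.ones((1, 8, 16, 16))
bad = single.std(axis=0, ddof=1).mean()  # NaN
```

In other words, the warning is expected behaviour for `batch size = 1`; using a batch size of at least 2 avoids it.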