Adding doc, adding a timer, fixing bugs. #84
base: main
Changes from 11 commits
@@ -3,7 +3,7 @@
 A GAN toolbox for researchers and developers with:
 - Progressive Growing of GAN (PGAN): https://arxiv.org/pdf/1710.10196.pdf
 - DCGAN: https://arxiv.org/pdf/1511.06434.pdf
-- To come: StyleGAN https://arxiv.org/abs/1812.04948
+- StyleGAN https://arxiv.org/abs/1812.04948

 <img src="illustration.png" alt="illustration">
 Picture: Generated samples from GANs trained on celebaHQ, fashionGen, DTD.
@@ -48,6 +48,21 @@ pip install -r requirements.txt
 - DTD: https://www.robots.ox.ac.uk/~vgg/data/dtd/
 - CIFAR10: http://www.cs.toronto.edu/~kriz/cifar.html

+For a quick start with CelebAHQ, you might:
+
+```
+git clone https://github.com/nperraud/download-celebA-HQ.git
+cd download-celebA-HQ
+conda create -n celebaHQ python=3
+source activate celebaHQ
+conda install jpeg=8d tqdm requests pillow==3.1.1 urllib3 numpy cryptography scipy
+pip install opencv-python==3.4.0.12 cryptography==2.1.4
+sudo apt-get install p7zip-full
+python download_celebA.py ./
+python download_celebA_HQ.py ./
+python make_HQ_images.py ./
+export PATH_TO_CELEBAHQ=`readlink -f ./celebA-HQ/512`
+```

Comment: I just wanted to facilitate people's life by copy-pasting the TLDR.
Reply: Good idea for celebaHQ, it was a pain in the *** to make the first time. I would however add a section ## Quick download.

 ## Quick training

 The datasets.py script allows you to prepare your datasets and build their corresponding configuration files.
@@ -64,8 +79,9 @@ And wait for a few days. Your checkpoints will be dumped in output_networks/cele
 For celebaHQ:

 ```
-python datasets.py celebaHQ $PATH_TO_CELEBAHQ -o $OUTPUT_DATASET - f
-python train.py PGAN -c config_celebaHQ.json --restart -n celebaHQ
+python datasets.py celebaHQ $PATH_TO_CELEBAHQ -o $OUTPUT_DATASET -f # Prepare the dataset and build the configuration file.
+python train.py PGAN -c config_celebaHQ.json --restart -n celebaHQ # Train.
+python eval.py inception -n celebaHQ -m PGAN # If you want to check the inception score.
 ```

Comment: not sure of this << - f >> ?
Reply: -f accelerates training by saving intermediate images of smaller sizes.
Comment: I'd rather not add the inception part here.

 Your checkpoints will be dumped in output_networks/celebaHQ. You should get 1024x1024 generations at the end.
@@ -130,7 +146,8 @@ Where:

 1 - MODEL_NAME is the name of the model you want to run. Currently, two models are available:
 - PGAN (progressive growing of GAN)
-- PPGAN (decoupled version of PGAN)
+- DCGAN
+- StyleGAN

 2 - CONFIGURATION_FILE (mandatory): path to a training configuration file. This file is a JSON file containing at least a pathDB entry with the path to the training dataset. See below for more information about this file.
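The configuration file described above can be very small. A minimal sketch, assuming only the mandatory pathDB entry (the path is a placeholder; any further fields depend on the model you train):

```json
{
  "pathDB": "/path/to/the/training/dataset"
}
```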
@@ -209,19 +226,19 @@ You need to use the eval.py script.

 You can generate more images from an existing checkpoint using:
 ```
-python eval.py visualization -n $modelName -m $modelType
+python eval.py visualization -n $runName -m $modelName
 ```

-Where modelType is in [PGAN, PPGAN, DCGAN] and modelName is the name given to your model. This script will load the last checkpoint detected at testNets/$modelName. If you want to load a specific iteration, please call:
+Where modelName is in [PGAN, StyleGAN, DCGAN] and runName is the name given to your run (trained model). This script will load the last checkpoint detected at output_networks/$runName. If you want to load a specific iteration, please call:

 ```
-python eval.py visualization -n $modelName -m $modelType -s $SCALE -i $ITER
+python eval.py visualization -n $runName -m $modelName -s $SCALE -i $ITER
 ```

 If your model is conditioned, you can ask the visualizer to print out some conditioned generations. For example:

 ```
-python eval.py visualization -n $modelName -m $modelType --Class T_SHIRT
+python eval.py visualization -n $runName -m $modelName --Class T_SHIRT
 ```

 This will plot a batch of T_SHIRTs in visdom. Please use the option --showLabels to see all the available labels for your model.
@@ -231,16 +248,21 @@ Will plot a batch of T_SHIRTS in visdom. Please use the option --showLabels to

 To save a randomly generated fake dataset from a checkpoint please use:

 ```
-python eval.py visualization -n $modelName -m $modelType --save_dataset $PATH_TO_THE_OUTPUT_DATASET --size_dataset $SIZE_OF_THE_OUTPUT
+python eval.py visualization -n $runName -m $modelName --save_dataset $PATH_TO_THE_OUTPUT_DATASET --size_dataset $SIZE_OF_THE_OUTPUT
 ```

 ### SWD metric

 Using the same kind of configuration file as above, just launch:

 ```
-python eval.py laplacian_SWD -c $CONFIGURATION_FILE -n $modelName -m $modelType
+python eval.py laplacian_SWD -c $CONFIGURATION_FILE -n $runName -m $modelName
 ```
+for the SWD score, to be minimized, or for the inception score:
+
+```
+python eval.py inception -c $CONFIGURATION_FILE -n $runName -m $modelName
+```
+to be maximized (see https://hal.inria.fr/hal-01850447/document for a discussion).

Comment: No, the SWD should be minimized.
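For intuition, the sliced Wasserstein distance underlying the laplacian_SWD score can be sketched generically. This is an illustrative stand-in, not the repository's implementation (which operates on Laplacian pyramids of image patches):

```python
import math
import random

def sliced_wasserstein(a, b, n_proj=128, seed=0):
    """Approximate the sliced Wasserstein distance between two equally
    sized sample sets (lists of equal-length tuples) by averaging the
    1-D Wasserstein-1 distance over random unit projections."""
    rng = random.Random(seed)
    dim = len(a[0])
    total = 0.0
    for _ in range(n_proj):
        # Draw a random direction and normalize it to unit length.
        d = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        norm = math.sqrt(sum(x * x for x in d))
        d = [x / norm for x in d]
        # Project both sample sets onto the direction and sort:
        # in 1-D, Wasserstein-1 is the mean gap between sorted samples.
        pa = sorted(sum(x * di for x, di in zip(p, d)) for p in a)
        pb = sorted(sum(x * di for x, di in zip(p, d)) for p in b)
        total += sum(abs(x - y) for x, y in zip(pa, pb)) / len(pa)
    return total / n_proj
```

Identical sample sets give a distance of 0, and the score grows as the two distributions drift apart, which is why lower is better.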
 Where $CONFIGURATION_FILE is the training configuration file called by train.py (see above): it must contain a "pathDB" field pointing to the dataset's directory. For example, if you followed the instructions of the Quick Training section to launch a training session on celebaHQ, your configuration file will be config_celebaHQ.json.
@@ -252,6 +274,7 @@ You can add optional arguments:

 ### Inspirational generation

 Inspirational generation consists in generating, with your GAN, an image that looks like a given input image.
 To make an inspirational generation, you first need to build a feature extractor:

 ```
@@ -261,14 +284,14 @@ python save_feature_extractor.py {vgg16, vgg19} $PATH_TO_THE_OUTPUT_FEATURE_EXTR
 Then run your model:

 ```
-python eval.py inspirational_generation -n $modelName -m $modelType --inputImage $pathTotheInputImage -f $PATH_TO_THE_OUTPUT_FEATURE_EXTRACTOR
+python eval.py inspirational_generation -n $runName -m $modelName --inputImage $pathTotheInputImage -f $PATH_TO_THE_OUTPUT_FEATURE_EXTRACTOR
 ```
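Conceptually, inspirational generation searches the latent space for a code whose generated image has features close to the input image's features. A toy sketch of that idea using random search, with hypothetical stand-in generator and feature-extractor functions (not the eval.py optimizer):

```python
import random

def inspire(generate, features, target_feats, dim, steps=200, seed=0):
    """Toy latent-space search: keep the latent vector whose generated
    'image' has features closest (squared L2) to the target's."""
    rng = random.Random(seed)
    best_z, best_d = None, float("inf")
    for _ in range(steps):
        z = [rng.gauss(0.0, 1.0) for _ in range(dim)]  # candidate latent
        f = features(generate(z))
        d = sum((a - b) ** 2 for a, b in zip(f, target_feats))
        if d < best_d:
            best_z, best_d = z, d
    return best_z, best_d

# Stand-ins: identity "generator" and "feature extractor".
gen = lambda z: z
feat = lambda img: img
target = feat(gen([0.5, -0.5]))
z, dist = inspire(gen, feat, target, dim=2)
```

The real pipeline replaces the random search with gradient-based optimization and the stand-ins with the GAN and the saved VGG feature extractor.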
 ### I have generated my metrics. How can I plot them in visdom?

 Just run:

 ```
-python eval.py metric_plot -n $modelName
+python eval.py metric_plot -n $runName
 ```

 ## LICENSE
@@ -1,5 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 import os
+import time

 from .standard_configurations.pgan_config import _C
 from ..progressive_gan import ProgressiveGAN

@@ -22,6 +23,7 @@ def __init__(self,
                  pathdb,
                  miniBatchScheduler=None,
                  datasetProfile=None,
+                 max_time=0,
                  configScheduler=None,
                  **kwargs):
         r"""

Comment: @brozi I added a timer, which will help for DFO or AS for GAN optimization.
@@ -30,6 +32,7 @@ def __init__(self,
             dataset
             - useGPU (bool): set to True if you want to use the available GPUs
             for the training procedure
+            - max_time (int): max number of seconds for training (0 = infinity).
             - visualisation (module): if not None, a visualisation module to
             follow the evolution of the training
             - lossIterEvaluation (int): size of the interval on which the

@@ -46,7 +49,7 @@ def __init__(self,
             - stopOnShitStorm (bool): should we stop the training if a diverging
             behavior is detected ?
         """
+        self.max_time = max_time
         self.configScheduler = {}
         if configScheduler is not None:
             self.configScheduler = {
@@ -208,6 +211,7 @@ def train(self):
                 + "_train_config.json")
         self.saveBaseConfig(pathBaseConfig)

+        start = time.time()
         for scale in range(self.startScale, n_scales):

             self.updateDatasetForScale(scale)

@@ -230,7 +234,8 @@ def train(self):
                 shiftAlpha += 1

             while shiftIter < self.modelConfig.maxIterAtScale[scale]:

+                if self.max_time > 0 and time.time() - start > self.max_time:
+                    break
                 self.indexJumpAlpha = shiftAlpha
                 status = self.trainOnEpoch(dbLoader, scale,
                                            shiftIter=shiftIter,
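The time-budget pattern added above can be exercised in isolation. A minimal sketch of the same check, with a stand-in step function rather than the trainer:

```python
import time

def train_with_budget(step, n_iters, max_time=0):
    """Run up to n_iters calls of step(); stop early once max_time
    seconds have elapsed (0 means no limit), mirroring the check
    inserted into the training loop above."""
    start = time.time()
    done = 0
    for _ in range(n_iters):
        # Budget check happens before each step, as in the PR.
        if max_time > 0 and time.time() - start > max_time:
            break
        step()
        done += 1
    return done
```

Checking the budget at the top of the loop means a long-running step is never interrupted mid-iteration; the loop simply refuses to start another one.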
Comment: :(
Comment: So no StyleGAN implementation in the end, @Molugan? 😞
Reply: There is a PR for styleGAN!
Reply: styleGAN incoming (#95) so no need to remove this one.