Merge pull request #276 from plyfager/ada-support
[Feature] support ada module and training
plyfager authored Apr 2, 2022
2 parents 0c5643e + 68ab951 commit 6b8ac42
Showing 15 changed files with 1,615 additions and 17 deletions.
6 changes: 3 additions & 3 deletions configs/styleganv3/metafile.yml
@@ -11,7 +11,7 @@ Models:
In Collection: StyleGANv3
Metadata:
Training Data: FFHQ
-Name: stylegan3_gamma32.8
+Name: stylegan3_noaug
Results:
- Dataset: FFHQ
Metrics:
@@ -24,7 +24,7 @@ Models:
In Collection: StyleGANv3
Metadata:
Training Data: Others
-Name: stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8
+Name: stylegan3_ada
Results:
- Dataset: Others
Metrics:
@@ -37,7 +37,7 @@ Models:
In Collection: StyleGANv3
Metadata:
Training Data: FFHQ
-Name: stylegan3_gamma2.0
+Name: stylegan3_t
Results:
- Dataset: FFHQ
Metrics:
@@ -0,0 +1,99 @@
_base_ = [
'../_base_/models/stylegan/stylegan3_base.py',
'../_base_/datasets/ffhq_flip.py', '../_base_/default_runtime.py'
]

synthesis_cfg = {
'type': 'SynthesisNetwork',
'channel_base': 65536,
'channel_max': 1024,
'magnitude_ema_beta': 0.999,
'conv_kernel': 1,
'use_radial_filters': True
}
r1_gamma = 3.3 # set by user
d_reg_interval = 16

load_from = 'https://download.openmmlab.com/mmgen/stylegan3/stylegan3_r_ffhq_1024_b4x8_cvt_official_rgb_20220329_234933-ac0500a1.pth' # noqa

# ada settings
aug_kwargs = {
'xflip': 1,
'rotate90': 1,
'xint': 1,
'scale': 1,
'rotate': 1,
'aniso': 1,
'xfrac': 1,
'brightness': 1,
'contrast': 1,
'lumaflip': 1,
'hue': 1,
'saturation': 1
}

model = dict(
type='StaticUnconditionalGAN',
generator=dict(
out_size=1024,
img_channels=3,
rgb2bgr=True,
synthesis_cfg=synthesis_cfg),
discriminator=dict(
type='ADAStyleGAN2Discriminator',
in_size=1024,
input_bgr2rgb=True,
data_aug=dict(type='ADAAug', aug_pipeline=aug_kwargs, ada_kimg=100)),
gan_loss=dict(type='GANLoss', gan_type='wgan-logistic-ns'),
disc_auxiliary_loss=dict(loss_weight=r1_gamma / 2.0 * d_reg_interval))

imgs_root = 'data/metfaces/images/'
data = dict(
samples_per_gpu=4,
train=dict(dataset=dict(imgs_root=imgs_root)),
val=dict(imgs_root=imgs_root))

ema_half_life = 10. # G_smoothing_kimg

ema_kimg = 10
ema_nimg = ema_kimg * 1000
ema_beta = 0.5**(32 / max(ema_nimg, 1e-8))

custom_hooks = [
dict(
type='VisualizeUnconditionalSamples',
output_dir='training_samples',
interval=5000),
dict(
type='ExponentialMovingAverageHook',
module_keys=('generator_ema', ),
interp_mode='lerp',
interp_cfg=dict(momentum=ema_beta),
interval=1,
start_iter=0,
priority='VERY_HIGH')
]

inception_pkl = 'work_dirs/inception_pkl/metface_1024x1024_noflip.pkl'
metrics = dict(
fid50k=dict(
type='FID',
num_images=50000,
inception_pkl=inception_pkl,
inception_args=dict(type='StyleGAN'),
bgr2rgb=True))

evaluation = dict(
type='GenerativeEvalHook',
interval=dict(milestones=[100000], interval=[10000, 5000]),
metrics=dict(
type='FID',
num_images=50000,
inception_pkl=inception_pkl,
inception_args=dict(type='StyleGAN'),
bgr2rgb=True),
sample_kwargs=dict(sample_model='ema'))

lr_config = None

total_iters = 160000
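The ADAAug wrapper configured in the discriminator above (data_aug=dict(type='ADAAug', aug_pipeline=aug_kwargs, ada_kimg=100)) is what makes the augmentation adaptive: every transform listed in aug_kwargs is applied with one shared probability p, and p is tuned on the fly from how confidently the discriminator classifies real images. The sketch below illustrates the underlying heuristic from the ADA paper, not the mmgen implementation; the class name, the 0.6 target, and the update method are assumptions for illustration. With ada_kimg=100, p can travel the full 0-to-1 range over roughly 100k real images when the overfitting signal stays saturated.

# Minimal sketch of the ADA probability update heuristic (illustration only,
# not the mmgen ADAAug code). real_logits are the discriminator outputs on
# real images for one update; n_imgs is how many real images they cover.
import torch

class AdaPSketch:
    def __init__(self, ada_target=0.6, ada_kimg=100):
        self.ada_target = ada_target  # desired E[sign(D(real))]
        self.ada_kimg = ada_kimg      # kimg needed for p to move from 0 to 1
        self.p = 0.0                  # shared augmentation probability

    def update(self, real_logits: torch.Tensor, n_imgs: int) -> float:
        # r_t estimates how confidently D separates reals from fakes.
        r_t = real_logits.sign().mean().item()
        # Raise p when D looks overfitted (r_t above target), lower it otherwise.
        step = n_imgs / (self.ada_kimg * 1000.0)
        self.p += step if r_t > self.ada_target else -step
        self.p = min(max(self.p, 0.0), 1.0)
        return self.p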
@@ -0,0 +1,96 @@
_base_ = [
'../_base_/models/stylegan/stylegan3_base.py',
'../_base_/datasets/ffhq_flip.py', '../_base_/default_runtime.py'
]

synthesis_cfg = {
'type': 'SynthesisNetwork',
'channel_base': 32768,
'channel_max': 512,
'magnitude_ema_beta': 0.999
}
r1_gamma = 6.6 # set by user
d_reg_interval = 16

load_from = 'https://download.openmmlab.com/mmgen/stylegan3/stylegan3_t_ffhq_1024_b4x8_cvt_official_rgb_20220329_235113-db6c6580.pth' # noqa
# ada settings
aug_kwargs = {
'xflip': 1,
'rotate90': 1,
'xint': 1,
'scale': 1,
'rotate': 1,
'aniso': 1,
'xfrac': 1,
'brightness': 1,
'contrast': 1,
'lumaflip': 1,
'hue': 1,
'saturation': 1
}

model = dict(
type='StaticUnconditionalGAN',
generator=dict(
out_size=1024,
img_channels=3,
rgb2bgr=True,
synthesis_cfg=synthesis_cfg),
discriminator=dict(
type='ADAStyleGAN2Discriminator',
in_size=1024,
input_bgr2rgb=True,
data_aug=dict(type='ADAAug', aug_pipeline=aug_kwargs, ada_kimg=100)),
gan_loss=dict(type='GANLoss', gan_type='wgan-logistic-ns'),
disc_auxiliary_loss=dict(loss_weight=r1_gamma / 2.0 * d_reg_interval))

imgs_root = 'data/metfaces/images/'
data = dict(
samples_per_gpu=4,
train=dict(dataset=dict(imgs_root=imgs_root)),
val=dict(imgs_root=imgs_root))

ema_half_life = 10. # G_smoothing_kimg

ema_kimg = 10
ema_nimg = ema_kimg * 1000
ema_beta = 0.5**(32 / max(ema_nimg, 1e-8))

custom_hooks = [
dict(
type='VisualizeUnconditionalSamples',
output_dir='training_samples',
interval=5000),
dict(
type='ExponentialMovingAverageHook',
module_keys=('generator_ema', ),
interp_mode='lerp',
interp_cfg=dict(momentum=ema_beta),
interval=1,
start_iter=0,
priority='VERY_HIGH')
]

inception_pkl = 'work_dirs/inception_pkl/metface_1024x1024_noflip.pkl'
metrics = dict(
fid50k=dict(
type='FID',
num_images=50000,
inception_pkl=inception_pkl,
inception_args=dict(type='StyleGAN'),
bgr2rgb=True))

evaluation = dict(
type='GenerativeEvalHook',
interval=dict(milestones=[80000], interval=[10000, 5000]),
metrics=dict(
type='FID',
num_images=50000,
inception_pkl=inception_pkl,
inception_args=dict(type='StyleGAN'),
bgr2rgb=True),
sample_kwargs=dict(sample_model='ema'))

lr_config = None

total_iters = 160000
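The two new configs differ mainly in the synthesis settings (the first, with conv_kernel=1 and use_radial_filters=True, matches the StyleGAN3-R checkpoint it loads; this second one matches the StyleGAN3-T checkpoint) and in r1_gamma. Two derived constants are worth spelling out: the EMA momentum follows the StyleGAN convention of a half-life measured in images (the hard-coded 32 is presumably the total batch size, 4 samples per GPU across 8 GPUs), and the R1 loss_weight folds in the lazy-regularization interval, since the penalty is only computed every d_reg_interval iterations. Re-deriving the numbers used above in plain Python:

# Re-deriving the constants used in the gamma6.6 config above (arithmetic only).
ema_kimg = 10
ema_nimg = ema_kimg * 1000                     # half-life of 10k images
ema_beta = 0.5 ** (32 / max(ema_nimg, 1e-8))   # 32 ~ total batch (4 imgs/GPU x 8 GPUs)
print(round(ema_beta, 5))                      # 0.99778

r1_gamma = 6.6
d_reg_interval = 16
# The R1 penalty (gamma / 2) * ||grad D(real)||^2 only runs every 16 iterations,
# so its weight is scaled up by the same factor (lazy regularization).
loss_weight = r1_gamma / 2.0 * d_reg_interval
print(loss_weight)                             # 52.8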
@@ -56,15 +56,14 @@
inception_args=dict(type='StyleGAN'),
bgr2rgb=True))

-inception_path = None
evaluation = dict(
type='GenerativeEvalHook',
interval=10000,
metrics=dict(
type='FID',
num_images=50000,
inception_pkl=inception_pkl,
-inception_args=dict(type='StyleGAN', inception_path=inception_path),
+inception_args=dict(type='StyleGAN'),
bgr2rgb=True),
sample_kwargs=dict(sample_model='ema'))

18 changes: 18 additions & 0 deletions mmgen/core/evaluation/metrics.py
@@ -953,6 +953,13 @@ def summary(self):
return self._result_dict

def extract_features(self, images):
"""Extracting image features.
Args:
images (torch.Tensor): Images tensor.
Returns:
torch.Tensor: Vgg16 features of input images.
"""
if self.use_tero_scirpt:
feature = self.vgg16(images, return_features=True)
else:
@@ -1278,6 +1285,17 @@ def summary(self):
return ppl_score

def get_sampler(self, model, batch_size, sample_model):
"""Get sampler for sampling along the path.
Args:
model (nn.Module): Generative model.
batch_size (int): Sampling batch size.
sample_model (str): Which model you want to use. ['ema',
'orig']. Defaults to 'ema'.
Returns:
Object: A sampler for calculating path length regularization.
"""
if sample_model == 'ema':
generator = model.generator_ema
else:
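The hunk above is truncated, but the new docstring spells out the intent of get_sampler: pick the EMA generator when sample_model is 'ema', otherwise fall back to the raw generator, then wrap it in a sampler used for the path-length computation. A rough sketch of just the selection step (illustrative only; the fallback attribute name 'generator' is an assumption, not visible in the diff):

# Illustrative selection logic implied by the get_sampler docstring; the real
# method also builds the sampler object used for path length regularization.
def select_generator(model, sample_model: str = 'ema'):
    if sample_model == 'ema':
        return model.generator_ema  # exponential-moving-average weights
    return model.generator          # raw (non-averaged) generator weights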
