# config.yml
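#
# Two-stage ESRGAN training configuration for Catalyst's config API.
# Objects are built from their `_target_` import paths; YAML anchors (&name)
# and aliases (*name) share values and whole blocks between the two stages.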
shared:
  upscale: &upscale 4  # 2, 4, 8
  patch_size: &patch_size 128  # 40, 64, 96, 128, 192
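
# Models are registered as a key-value dict (`_key_value: true`); the anchored
# keys `generator` and `discriminator` are referenced below by the runner,
# the optimizers, and the per-network callbacks.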
model:
  _key_value: true

  &generator_model generator:
    _target_: esrgan.models.EncoderDecoderNet
    encoder:
      _target_: esrgan.models.ESREncoder
      in_channels: &num_channels 3
      out_channels: &latent_channels 64
      num_basic_blocks: 16
      growth_channels: 32
      activation: &activation
        _mode_: partial
        _target_: torch.nn.LeakyReLU
        negative_slope: 0.2
        inplace: true
      residual_scaling: 0.2
    decoder:
      _target_: esrgan.models.ESRNetDecoder
      in_channels: *latent_channels
      out_channels: *num_channels
      scale_factor: *upscale
      activation: *activation

  &discriminator_model discriminator:
    _target_: esrgan.models.VGGConv
    encoder:
      _target_: esrgan.models.StridedConvEncoder
    pool:
      _target_: catalyst.contrib.layers.AdaptiveAvgPool2d
      output_size: [7,7]
    head:
      _target_: esrgan.models.LinearHead
      in_channels: 25088  # 512 * (7x7)
      out_channels: 1
      latent_channels: [1024]
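
# Experiment-wide args and the runner: `GANConfigRunner` looks up the two
# networks in the model dict by the keys anchored above.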
args:
  logdir: logs

runner:
  _target_: esrgan.runner.GANConfigRunner
  generator_key: *generator_model
  discriminator_key: *discriminator_model
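
# Training follows the usual ESRGAN recipe in two stages: stage 1 pretrains
# the generator with a pixel-wise (PSNR-oriented) loss only; stage 2 then
# fine-tunes generator and discriminator adversarially.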
stages:
  stage1_supervised:
    num_epochs: 10000

    loaders: &loaders
      train: &train_loader
        _target_: torch.utils.data.DataLoader
        dataset:
          _target_: torch.utils.data.ConcatDataset
          datasets:
            - &div2k_dataset
              _target_: esrgan.datasets.DIV2KDataset
              root: data
              train: true
              target_type: bicubic_X4
              patch_size: [*patch_size,*patch_size]
              transform:
                _target_: albumentations.Compose
                transforms:
                  - &spatial_transforms
                    _target_: albumentations.Compose
                    transforms:
                      - _target_: albumentations.OneOf
                        transforms:
                          - _target_: albumentations.Flip
                            p: 0.75  # p = 1/4 (vflip) + 1/4 (hflip) + 1/4 (flip)
                          - _target_: albumentations.Transpose
                            p: 0.25  # p = 1/4
                        p: 0.5
                    additional_targets:
                      real_image: image
                  - &hard_transforms
                    _target_: albumentations.Compose
                    transforms:
                      - _target_: albumentations.CoarseDropout
                        max_holes: 8
                        max_height: 2
                        max_width: 2
                      - _target_: albumentations.ImageCompression
                        quality_lower: 65
                        p: 0.25
                  - &post_transforms
                    _target_: albumentations.Compose
                    transforms:
                      - _target_: albumentations.Normalize
                        mean: 0
                        std: 1
                      - _target_: albumentations.pytorch.ToTensorV2
                    additional_targets:
                      real_image: image
              low_resolution_image_key: image
              high_resolution_image_key: real_image
              download: true
            - &flickr2k_dataset
              << : [*div2k_dataset]  # Flickr2K with the same params as in `DIV2KDataset`
              _target_: esrgan.datasets.Flickr2KDataset
        batch_size: 16
        shuffle: true
        num_workers: 8
        pin_memory: true
        drop_last: true
      valid:
        << : [*train_loader]
        dataset:  # redefine dataset to use only DIV2K
          << : [*div2k_dataset]
          train: false
          transform: *post_transforms
        batch_size: 1
        drop_last: false
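
    # Stage 1 objective: a single pixel-wise content loss (L1 by default)
    # between the super-resolved output and the ground-truth image.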
    criterion: &criterions
      content_loss:
        # `torch.nn.L1Loss` | `torch.nn.MSELoss`
        _target_: torch.nn.L1Loss

    optimizer:
      _key_value: true

      generator:
        _target_: torch.optim.Adam
        lr: 0.0002
        weight_decay: 0.0
        _model: *generator_model

    scheduler:
      _key_value: true

      generator:
        _target_: torch.optim.lr_scheduler.StepLR
        step_size: 500
        gamma: 0.5
        _optimizer: generator
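
    # Stage 1 callbacks: PSNR/SSIM validation metrics computed with `piq`,
    # the content-loss criterion, an optimizer step with gradient value
    # clipping, and an LR schedule driven by the validation content loss.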
    callbacks: &callbacks
      psnr_metric:
        _target_: catalyst.callbacks.FunctionalMetricCallback
        metric_fn:
          _target_: piq.psnr
          data_range: 1.0
          reduction: mean
          convert_to_greyscale: false
        input_key: real_image
        target_key: fake_image
        metric_key: psnr
      ssim_metric:
        _target_: catalyst.callbacks.FunctionalMetricCallback
        metric_fn:
          _target_: piq.ssim
          kernel_size: 11
          kernel_sigma: 1.5
          data_range: 1.0
          reduction: mean
          k1: 0.01
          k2: 0.03
        input_key: real_image
        target_key: fake_image
        metric_key: ssim

      loss_content:
        _target_: catalyst.callbacks.CriterionCallback
        input_key: real_image
        target_key: fake_image
        metric_key: loss_content
        criterion_key: content_loss

      optimizer_generator:
        _target_: catalyst.callbacks.OptimizerCallback
        metric_key: loss_content
        model_key: *generator_model
        optimizer_key: generator
        grad_clip_fn: &grad_clip_fn
          _mode_: partial
          _target_: torch.nn.utils.clip_grad_value_
          clip_value: 5.0

      scheduler_generator:
        _target_: catalyst.callbacks.SchedulerCallback
        scheduler_key: generator
        loader_key: valid
        metric_key: loss_content
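
  # Stage 2: adversarial fine-tuning. Loaders, criteria, and callbacks from
  # stage 1 are reused via YAML merge keys (`<< :`) and extended with the
  # GAN-specific parts; hard augmentations are dropped from the train loader.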
  stage2_gan:
    num_epochs: 8000

    loaders:
      << : [*loaders]
      train:
        << : [*train_loader]
        dataset:
          << : [*div2k_dataset]
          transform:
            _target_: albumentations.Compose
            transforms:
              - *spatial_transforms
              - *post_transforms
        batch_size: 16
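
    # On top of the content loss merged from stage 1, stage 2 adds a perceptual
    # loss on conv5_4 (VGG-style) feature maps and a relativistic adversarial
    # loss for both the generator and the discriminator.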
    criterion:
      << : [*criterions]
      perceptual_loss:
        _target_: esrgan.nn.PerceptualLoss
        layers:
          conv5_4: 1.0
      adversarial_generator_loss:
        # `esrgan.nn.RelativisticAdversarialLoss` | `esrgan.nn.AdversarialLoss`
        _target_: &adversarial_criterion esrgan.nn.RelativisticAdversarialLoss
        mode: generator
      adversarial_discriminator_loss:
        _target_: *adversarial_criterion
        mode: discriminator
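
    # Each network gets its own AdamW optimizer and MultiStepLR schedule;
    # `_model` / `_optimizer` bind them to the matching model-dict entries.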
    optimizer:
      _key_value: true

      generator:
        _target_: torch.optim.AdamW
        lr: 0.0001
        weight_decay: 0.0
        _model: *generator_model
      discriminator:
        _target_: torch.optim.AdamW
        lr: 0.0001
        weight_decay: 0.0
        _model: *discriminator_model

    scheduler:
      _key_value: true

      generator:
        _target_: torch.optim.lr_scheduler.MultiStepLR
        milestones: &scheduler_milestones [1000,2000,4000,6000]
        gamma: 0.5
        _optimizer: generator
      discriminator:
        _target_: torch.optim.lr_scheduler.MultiStepLR
        milestones: *scheduler_milestones
        gamma: 0.5
        _optimizer: discriminator
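
    # The generator objective is a weighted sum, as in the ESRGAN paper:
    # 0.01 * content + 1.0 * perceptual + 0.005 * adversarial. The
    # discriminator is trained on its own relativistic adversarial loss.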
    callbacks:
      # re-use `psnr_metric`, `ssim_metric`, and `loss_content` callbacks
      << : [*callbacks]

      loss_perceptual:
        _target_: catalyst.callbacks.CriterionCallback
        input_key: real_image
        target_key: fake_image
        metric_key: loss_perceptual
        criterion_key: perceptual_loss
      loss_adversarial:
        _target_: catalyst.callbacks.CriterionCallback
        input_key: g_fake_logits  # first argument of criterion is fake_logits
        target_key: g_real_logits  # second argument of criterion is real_logits
        metric_key: loss_adversarial
        criterion_key: adversarial_generator_loss
      loss_generator:
        _target_: catalyst.callbacks.MetricAggregationCallback
        metric_key: &generator_loss loss_generator
        metrics:
          loss_content: 0.01
          loss_perceptual: 1.0
          loss_adversarial: 0.005
        mode: weighted_sum

      loss_discriminator:
        _target_: catalyst.callbacks.CriterionCallback
        input_key: d_fake_logits
        target_key: d_real_logits
        metric_key: &discriminator_loss loss_discriminator
        criterion_key: adversarial_discriminator_loss

      optimizer_generator:
        _target_: catalyst.callbacks.OptimizerCallback
        metric_key: *generator_loss
        model_key: *generator_model
        optimizer_key: generator
        grad_clip_fn: *grad_clip_fn
      optimizer_discriminator:
        _target_: catalyst.callbacks.OptimizerCallback
        metric_key: *discriminator_loss
        model_key: *discriminator_model
        optimizer_key: discriminator
        grad_clip_fn: *grad_clip_fn

      scheduler_generator:
        _target_: catalyst.callbacks.SchedulerCallback
        scheduler_key: generator
        loader_key: valid
        metric_key: *generator_loss
      scheduler_discriminator:
        _target_: catalyst.callbacks.SchedulerCallback
        scheduler_key: discriminator
        loader_key: valid
        metric_key: *discriminator_loss
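
# Usage sketch (an assumption, not taken from this repository's docs): a config
# like this is typically launched through Catalyst's config entrypoint, e.g.
#
#   catalyst-dl run -C config.yml
#
# with checkpoints and logs written to `args.logdir` ("logs" above).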