-
Notifications
You must be signed in to change notification settings - Fork 1
/
discriminator_config.py
49 lines (39 loc) · 1007 Bytes
/
discriminator_config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
"""module containing configurations for the model and training routine"""
from torch.nn.functional import relu, selu, leaky_relu, tanh
from torch.nn import MSELoss, L1Loss, SmoothL1Loss
from torch.optim import Adam
# Hyperparameters for the discriminator network. The meaning of each key is
# defined by the model code that consumes this dict (not visible here).
discriminator = dict(
    n_layers=7,
    kernel_size=(3, 3),
    activation=selu,  # torch.nn.functional.selu, imported at the top of the file
    channel_factor=8,
    max_channels=1024,
    input_channels=1,
    n_residual=(0, 0),
    affine=True,
)
# Dataset settings.
# NOTE(review): 'vmin' is the string expression 'mean-0.5' while 'vmax' is a
# plain float — presumably the dataset code parses 'vmin'; confirm there.
dataset = dict(
    vmin='mean-0.5',
    vmax=1.0,
    whitening=False,
)
# Data-loading settings; the key names match torch.utils.data.DataLoader
# keyword arguments — presumably forwarded there by the caller (confirm).
dataloader = dict(batch_size=32, shuffle=True, num_workers=2)
# Optimizer settings. 'optimizer_type' holds the optimizer class itself
# (torch.optim.Adam), not an instance.
# NOTE(review): 'lr_decay' looks like a (factor, interval) pair — confirm
# against the training loop that reads it.
optimizer = dict(
    optimizer_type=Adam,
    learning_rate=0.005,
    apex=False,
    lr_decay=(0.99, 1e5),
)
# Training-run settings: number of epochs and a run name.
training = dict(epochs=25, name='discriminator-test')
def get_complete_config():
    """Collect every module-level dict defined in this module into one mapping.

    Returns:
        dict: maps each top-level variable name whose value is a ``dict``
        (e.g. ``discriminator``, ``dataset``, ``dataloader``, ``optimizer``,
        ``training``) to that dict. Dunder names such as ``__builtins__``
        are skipped. The section dicts are returned by reference, not copied,
        so mutating them mutates this module's configuration.
    """
    # Read this module's own namespace. An ``import config`` here would load
    # whichever module named ``config`` is on sys.path — not this file — so
    # the returned sections would not reflect the values defined above.
    namespace = globals()
    return {
        key: value
        for key, value in namespace.items()
        # __builtins__ is a dict inside imported CPython modules; skip it and
        # any other dunder entry so only user-defined config sections remain.
        if isinstance(value, dict)
        and not (key.startswith('__') and key.endswith('__'))
    }