"""
Run this script to train a ConvNet on MNIST.
"""
from functools import partial
import torch
from torch import optim
from sacred import Experiment, SETTINGS
from sacred.observers import FileStorageObserver
from visdom_observer.visdom_observer import VisdomObserver
import pytorch_utils.sacred_trainer as st
from pytorch_utils.updaters import averager
from model_ingredient import model_ingredient, make_model
from data_ingredient import data_ingredient, make_dataloaders
from training_functions import train_on_batch, create_val_scheduler_callback
from adversarial import fgsm, pgd, AdversarialLoader
torch.backends.cudnn.benchmark = True
SETTINGS.CAPTURE_MODE = 'no'
ex = Experiment('adv-odenet_mnist_randtime',
ingredients=[model_ingredient, data_ingredient])
SAVE_DIR = 'runs/AdvODEnetRandTimeMnist'
ex.observers.append(FileStorageObserver.create(SAVE_DIR))
ex.observers.append(VisdomObserver())
ATTACKS = {
'fgsm':fgsm,
'pgd':pgd
}
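# The attack to train against is selected by name from the config (see
# attack_config below), so registering a new attack function in this dict
# is enough to make it selectable.
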
# ------- COMBINING TRAIN AND ADVERSARIAL LOADERS -------
class CombineDataloaders:
    """Combine multiple dataloaders or iterators which yield images and
    labels into one iterator.

    Parameters
    ----------
    *loaders : iterables
        The dataloaders to combine. Each must yield (images, labels) pairs.
    """

    def __init__(self, *loaders):
        self.loaders = loaders

    def __iter__(self):
        iters = [iter(loader) for loader in self.loaders]
        while True:
            try:
                items = [next(it) for it in iters]
            except StopIteration:
                # Stop once the shortest loader is exhausted; letting the
                # bare StopIteration escape this generator would raise
                # RuntimeError under PEP 479.
                return
            # Concatenate the batches from all loaders along the batch axis.
            images = torch.cat([i[0] for i in items])
            labels = torch.cat([i[1] for i in items])
            yield images, labels

    def __len__(self):
        return min(len(loader) for loader in self.loaders)
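
# Usage sketch (hypothetical loaders; not executed as part of training).
# Each step draws one batch from every loader and concatenates them, so the
# effective batch size is the sum of the per-loader batch sizes:
#
#   clean = DataLoader(mnist, batch_size=64)
#   adv = AdversarialLoader(model, clean, attack_fn)
#   for images, labels in CombineDataloaders(clean, adv):
#       ...  # 128 examples per step: 64 clean + 64 adversarial
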
# ----------------OPTIMIZER-----------------
@ex.config
def optimizer_config():
    """Config for optimizer

    Currently available opts (types of optimizers):
        adam
        adamax
        rmsprop
    """
    lr = 0.001  # learning rate
    opt = 'adam'  # type of optimizer
    weight_decay = 0  # l2 regularization weight_decay (lambda)


@ex.capture
def make_optimizer(model, lr, opt, weight_decay):
    """Make an optimizer of the given type (opt) for the given model's
    parameters with the given learning rate (lr) and weight decay."""
    optimizers = {
        'adam': optim.Adam,
        'adamax': optim.Adamax,
        'rmsprop': optim.RMSprop,
    }
    optimizer = optimizers[opt](model.parameters(), lr=lr,
                                weight_decay=weight_decay)
    return optimizer
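
# Because of @ex.capture, lr, opt and weight_decay are injected from the
# sacred config, so call sites pass only the model:
#
#   optimizer = make_optimizer(model)
#
# Any of these can be overridden from the command line, e.g.
#   python adv_train.py with opt=rmsprop lr=0.0003
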
# -----------CALLBACK FOR LR SCHEDULING-------------
@ex.config
def scheduler_config():
    """Config for lr scheduler"""
    milestones = [50, 100]
    gamma = 0.5  # factor to reduce lr by at each milestone


@ex.capture
def make_scheduler_callback(optimizer, milestones, gamma):
    """Create a MultiStepLR scheduler callback for the optimizer
    using the config."""
    return create_val_scheduler_callback(optimizer, milestones, gamma)
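
# With the defaults above, the learning rate is multiplied by gamma=0.5 at
# epochs 50 and 100. The exact stepping behaviour (assumed here to be tied
# to the validation pass, given the name) lives in
# create_val_scheduler_callback in training_functions.py.
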
@ex.config
def train_config():
    epochs = 100
    save_every = 1
    start_epoch = 1


@ex.config
def attack_config():
    attack = 'pgd'  # which attack to train against; a key of ATTACKS
    epsilon = 0.3
    pgd_step_size = 0.01
    pgd_num_steps = 40
    pgd_random_start = True
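
# Judging from the partial(...) calls in main below, the attack functions
# are assumed to have signatures along the lines of
#   fgsm(model, images, labels, epsilon=...)
#   pgd(model, images, labels, epsilon=..., step_size=..., num_steps=...,
#       random_start=...)
# where epsilon is the perturbation budget and the pgd_* settings control
# the iterated attack.
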
@ex.automain
def main(_run,
         attack,
         epsilon,
         pgd_step_size,
         pgd_num_steps,
         pgd_random_start):
    dset, train, val, test = make_dataloaders()
    model = make_model()
    optimizer = make_optimizer(model)
    callback = make_scheduler_callback(optimizer)

    # Build the attack function; only pgd takes the extra step settings.
    if attack == 'pgd':
        attack_fn = partial(ATTACKS[attack],
                            epsilon=epsilon,
                            step_size=pgd_step_size,
                            num_steps=pgd_num_steps,
                            random_start=pgd_random_start)
    else:
        attack_fn = partial(ATTACKS[attack], epsilon=epsilon)

    # Wrap the clean training loader to generate adversarial examples on
    # the fly, then combine the two so each training batch is part clean,
    # part adversarial.
    adv_train = AdversarialLoader(model, train, attack_fn)
    final_train = CombineDataloaders(train, adv_train)
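
    # st.loop receives the full sacred config (epochs, save_every,
    # start_epoch, attack settings, ...) merged with the objects built
    # above; keys in the explicit dict take precedence over config keys
    # of the same name.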
    st.loop(
        **{**_run.config,
           **dict(_run=_run,
                  model=model,
                  optimizer=optimizer,
                  save_dir=SAVE_DIR,
                  trainOnBatch=train_on_batch,
                  train_loader=final_train,
                  val_loader=val,
                  callback=callback,
                  callback_metric_names=['val_loss', 'val_acc',
                                         'learning_rate'],
                  batch_metric_names=['loss', 'acc', 'nfef', 'nfeb'],
                  updaters=[averager]*4)})