-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
196 lines (159 loc) · 7.28 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
from argparse import ArgumentParser
import os
import git
import torch as ch
import cox
import cox.utils
import cox.store
from exp_library.model_utils import make_and_restore_model, \
check_experiment_status, \
model_dataset_from_store
from exp_library.datasets import DATASETS
from exp_library.decoupled_train import train_model, eval_model
from exp_library.tools import constants, helpers
from exp_library import defaults, __version__
from exp_library.defaults import check_and_fill_args
from exp_library.loaders import DuplicateLoader
from exp_library.pytorch_modelsize import SizeEstimator
from torch.nn.utils import parameters_to_vector as flatten
def log_norm(store, mod, log_info):
    """Append the model's current weight norm to the store's 'custom' table.

    Intended to be installed as ``args.epoch_hook``: called once per epoch
    with the cox store, the model, and the standard ``log_info`` dict
    (must contain 'epoch').
    """
    curr_params = flatten(mod.parameters())
    # .item() yields a plain Python float, matching the table's declared
    # `float` column, instead of the 0-d numpy scalar that .numpy() returns.
    log_info_custom = {'epoch': log_info['epoch'],
                       'weight_norm': ch.norm(curr_params).detach().cpu().item()}
    store['custom'].append_row(log_info_custom)
# Build the command-line parser from the library's standard argument groups,
# then extend it with experiment-specific options.
parser = ArgumentParser()
parser = defaults.add_args_to_parser(defaults.CONFIG_ARGS, parser)
parser = defaults.add_args_to_parser(defaults.MODEL_LOADER_ARGS, parser)
parser = defaults.add_args_to_parser(defaults.TRAINING_ARGS, parser)
parser = defaults.add_args_to_parser(defaults.PGD_ARGS, parser)

# Extra experiment-specific arguments: [name, type-or-choices, help, default].
# NOTE: `extra_args` is also consumed by `setup_args` via check_and_fill_args,
# so keep this list at module level.
extra_args = [
    ['weight-decay-cl', float, 'weight decay classifier', 0.0001],
    ['lr-cl', float, 'learning rate of the classifier', 0.0001],
    ['cifar-imb', float, 'imbalance factor for cifar', -1],
    ['entr-reg', [0, 1], 'penalize entropic outputs', 0],
    ['inner-batch-factor', int, 'inner batch factor', 2],
    ['class-balanced', [0, 1], 'class-aware sampler for classifier', 0]
]
parser = defaults.add_args_to_parser(extra_args, parser)
def main():
    """Entry point: parse CLI args, then either resume an in-progress
    experiment from its cox store or set up and launch a fresh run."""
    args = parser.parse_args()
    args = cox.utils.Parameters(args.__dict__)

    # First check whether this experiment (out_dir/exp_name) already exists.
    is_training = False
    checkpoint = None
    model = None
    exp_dir_path = os.path.join(args.out_dir, args.exp_name) if args.exp_name else None
    # Guard the None case: os.path.exists(None) raises TypeError when
    # no --exp-name was given.
    if exp_dir_path is not None and os.path.exists(exp_dir_path):
        is_training = check_experiment_status(args.out_dir, args.exp_name)

    if is_training and (not args.resume or args.eval_only):
        # Unfinished run found: reload model, checkpoint, store and args from it.
        s = cox.store.Store(args.out_dir, args.exp_name)
        model, checkpoint, _, store, args = model_dataset_from_store(s,
            overwrite_params={}, which='last', mode='a', parallel=True)
    else:
        # Fresh run: validate/fill args and create a new store with metadata.
        args = setup_args(args)
        store = setup_store_with_metadata(args)

    main_worker(args, model=model, checkpoint=checkpoint, store=store)
def main_worker(args, model, checkpoint, store):
    '''Given arguments from `setup_args` and a store from `setup_store`,
    trains a model. Check out the argparse object in this file for
    argument options.

    Returns the trained model (or the result of `eval_model` when
    ``args.eval_only`` is set). `model` and `checkpoint` are currently
    rebuilt via `make_and_restore_model`.
    '''
    # MAKE DATASET AND LOADERS
    data_path = os.path.expandvars(args.data)
    dataset = DATASETS[args.dataset](data_path)

    # Optionally restrict CIFAR to an artificially imbalanced subset.
    subset = None
    if 'cifar' in args.dataset and args.cifar_imb > 0:
        from custom_fuctions import get_imb_subset
        if args.dataset == 'cifar':
            from torchvision.datasets import CIFAR10
            targets = CIFAR10(data_path).targets
            subset = get_imb_subset(targets, args.cifar_imb, 'cifar10')
        else:
            from torchvision.datasets import CIFAR100
            targets = CIFAR100(data_path).targets
            subset = get_imb_subset(targets, args.cifar_imb, 'cifar100')

    train_loader, val_loader = dataset.make_loaders(args.workers,
        args.batch_size, data_aug=bool(args.data_aug),
        subset=subset)
    # Separate loader for the classifier head: larger batches
    # (batch_size * inner_batch_factor) and an optional class-aware sampler.
    class_loader, _ = dataset.make_loaders(args.workers,
        args.batch_size * args.inner_batch_factor,
        data_aug=bool(args.data_aug), class_sampler=args.class_balanced)

    train_loader = helpers.DataPrefetcher(train_loader)
    val_loader = helpers.DataPrefetcher(val_loader)
    class_loader = helpers.DataPrefetcher(class_loader)
    loaders = (train_loader, class_loader, val_loader)

    # MAKE MODEL
    model, checkpoint = make_and_restore_model(args, dataset=dataset)
    # Unwrap DataParallel so the raw architecture is reachable below.
    if hasattr(model, 'module'): model = model.module

    # Optional entropy regularizer: penalizes high-entropy (uncertain) outputs.
    if args.entr_reg:
        def reg_loss(model, inp, targ):
            out, _ = model(inp)
            # log_softmax is numerically stable, unlike log(softmax(...)),
            # which yields -inf/NaN when a probability underflows to zero.
            log_prob = ch.log_softmax(out, dim=1)
            return -(log_prob * log_prob.exp()).sum(dim=1).mean()
        args.regularizer = reg_loss

    if args.eval_only:
        return eval_model(args, model, val_loader, store=store)

    # Collect the feature-extractor parameters (stem + residual layers,
    # i.e. everything except the final classifier head); only these are
    # updated by the main optimizer.
    root_model = model.model
    feats_net_pars = list(root_model.conv1.parameters()) + list(root_model.bn1.parameters())
    for layer in (root_model.layer1, root_model.layer2,
                  root_model.layer3, root_model.layer4):
        feats_net_pars += list(layer.parameters())

    model = train_model(args, model, loaders, store=store, update_params=feats_net_pars)
    return model
def setup_args(args):
    '''
    Fill the args object with reasonable defaults from
    :mod:`robustness.defaults`, and also perform a sanity check to make sure no
    args are missing.

    Raises:
        ValueError: if ``--eval-only`` is given without ``--resume``.
    '''
    # Values from the optional JSON config file override non-None CLI values.
    if args.config_path:
        args = cox.utils.override_json(args, args.config_path)

    ds_class = DATASETS[args.dataset]
    args = check_and_fill_args(args, defaults.CONFIG_ARGS, ds_class)

    if not args.eval_only:
        args = check_and_fill_args(args, defaults.TRAINING_ARGS, ds_class)
    if args.adv_train or args.adv_eval:
        args = check_and_fill_args(args, defaults.PGD_ARGS, ds_class)
    args = check_and_fill_args(args, defaults.MODEL_LOADER_ARGS, ds_class)
    args = check_and_fill_args(args, extra_args, ds_class)

    # Raise a real exception instead of `assert`: asserts are stripped
    # under `python -O`, which would silently skip this validation.
    if args.eval_only and args.resume is None:
        raise ValueError("Must provide a resume path if only evaluating")
    return args
def setup_store_with_metadata(args):
    '''
    Sets up a store for training according to the arguments object. See the
    argparse object above for options.
    '''
    # Record the current git commit on args; fall back to the library
    # version when not running from a git checkout.
    try:
        this_dir = os.path.dirname(os.path.realpath(__file__))
        repo = git.Repo(path=this_dir, search_parent_directories=True)
        args.version = repo.head.object.hexsha
    except git.exc.InvalidGitRepositoryError:
        args.version = __version__

    # Create the store and persist the full argument set as metadata.
    store = cox.store.Store(args.out_dir, args.exp_name)
    metadata = args.__dict__
    store.add_table('metadata', cox.store.schema_from_dict(metadata))
    store['metadata'].append_row(metadata)

    # Table for per-epoch custom logging, filled in by `log_norm`.
    store.add_table('custom', {'epoch': int, 'weight_norm': float})
    args.epoch_hook = log_norm
    return store
# Run the experiment when executed as a script (not on import).
if __name__ == "__main__":
    main()