-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtrain_backbone_ablation.py
86 lines (76 loc) · 2.4 KB
/
train_backbone_ablation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import sys
import json
import datetime
import config
import itertools
from collections import defaultdict
from core.train_engine import check_config, train, configs_from_grid
# Output directory for the backbone ablation, read from the project's `config`
# module. The script writes its grid JSON files here and passes it as the
# output path to `train` for every configuration.
ROOT_DIR = config.paths.ablation_bkb_dir
def build_grid(latent_config):
    """Return the hyper-parameter grid for one DCGEncDec backbone variant.

    All ablation variants share the same training, model and data settings
    and differ only in ``latent_config``. Every leaf is a list of candidate
    values, which is the form handed to ``configs_from_grid``.

    Args:
        latent_config: dict holding the ``'n_branches'`` and
            ``'out_features_branch'`` candidate lists; it is inserted into
            the grid as given (not copied).

    Returns:
        A newly built nested dict. Apart from ``latent_config``, no objects
        are shared between grids returned by separate calls.
    """
    return {
        'train_args': {
            'learning_rate': [0.001],
            'batch_size': [100],
            'n_epochs': [50],
            'rec_loss_w': [1.0],
            'top_loss_w': [0.0, 1.0, 10.0, 20.0, 40.0],
        },
        'model_args': {
            'class_id': ['DCGEncDec'],
            'kwargs': {
                'filter_config': [[3, 32, 64, 128]],
                'input_config': [[3, 32, 32]],
                'latent_config': latent_config,
            },
        },
        'data_args': {
            'dataset': ['cifar100'],
            'subset_ratio': [1.0],
            'train': [True],
        },
    }


def main(root_dir=None):
    """Run every configuration of the backbone ablation study.

    Each grid is saved as ``<root_dir>/<grid name>.json``. The configurations
    expanded from all grids are then trained one after another, each with
    ``root_dir`` as the output path.

    Args:
        root_dir: output directory; defaults to the module-level ``ROOT_DIR``.
    """
    path = ROOT_DIR if root_dir is None else root_dir
    grids = {
        'grid_one_branch': build_grid({
            'n_branches': [1],
            'out_features_branch': [160],
        }),
        # NOTE(review): whether configs_from_grid pairs these two lists
        # element-wise or takes their cartesian product is not visible
        # here; confirm before reading results.
        'grid_many_branches': build_grid({
            'n_branches': [32, 16, 8],
            'out_features_branch': [20, 10, 5],
        }),
    }
    # The output directory is fixed (not time-stamped), so it already exists
    # on every re-run; without exist_ok, makedirs raised FileExistsError.
    # Re-running overwrites the grid JSON files written below.
    os.makedirs(path, exist_ok=True)
    configs = []
    for name, grid in grids.items():
        configs += configs_from_grid(grid)
        with open(os.path.join(path, name + '.json'), 'w') as fid:
            json.dump(grid, fid)
    for i, c in enumerate(configs, start=1):
        print(c)
        print('Config {}/{}'.format(i, len(configs)))
        train(path, c)


if __name__ == "__main__":
    main()