
Commit 7b74b58

teacher training files are updated for better logging
1 parent f68fc61 commit 7b74b58

File tree

6 files changed: +35 -25 lines


dataset/__init__.py

Whitespace-only changes.

dataset/cifar100.py

+6 -6

@@ -95,14 +95,14 @@ class CIFAR100Instance(datasets.CIFAR100):
 
     def __getitem__(self, index):
 
-        if torch.__version__[0] == '0':
+        # if torch.__version__[0] == '0':
 
-            if self.train:
-                img, target = self.train_data[index], self.train_labels[index]
-            else:
-                img, target = self.test_data[index], self.test_labels[index]
+        if self.train:
+            img, target = self.train_data[index], self.train_labels[index]
         else:
-            img, target = self.data[index], self.targets[index]
+            img, target = self.test_data[index], self.test_labels[index]
+        # else:
+        #     img, target = self.data[index], self.targets[index]
 
         # doing this so that it is consistent with all other datasets
         # to return a PIL Image
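
Note: the rewritten branch hard-codes the older torchvision attribute names (train_data/train_labels, test_data/test_labels) and comments out the data/targets path used by newer torchvision releases. A minimal, self-contained sketch of a version-agnostic accessor (the helper name get_raw_item is hypothetical, not part of this commit):

# Hypothetical helper: fetch a raw (img, target) pair regardless of which
# attribute layout the installed torchvision CIFAR100 class exposes.
def get_raw_item(dataset, index):
    if hasattr(dataset, 'data'):  # newer torchvision: unified data/targets
        return dataset.data[index], dataset.targets[index]
    if dataset.train:             # older torchvision: split train/test attributes
        return dataset.train_data[index], dataset.train_labels[index]
    return dataset.test_data[index], dataset.test_labels[index]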

models/mobilenetv2.py

+5 -2

@@ -162,18 +162,21 @@ def forward(self, x, is_feat=False, preact=False):
     def _initialize_weights(self):
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
-                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels #/ m.groups
+                # print(m.kernel_size[0], m.kernel_size[1], m.in_channels, m.out_channels, m.groups)
                 m.weight.data.normal_(0, math.sqrt(2. / n))
                 if m.bias is not None:
                     m.bias.data.zero_()
+
             elif isinstance(m, nn.BatchNorm2d):
                 m.weight.data.fill_(1)
                 m.bias.data.zero_()
             elif isinstance(m, nn.Linear):
                 n = m.weight.size(1)
                 m.weight.data.normal_(0, 0.01)
                 m.bias.data.zero_()
-
+        print("initializing done!!!")
+        # exit()
 
 def mobilenetv2_T_w(T, W, feature_dim=100):
     model = MobileNetV2(T=T, feature_dim=feature_dim, width_mult=W)
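
Note: the manual formula above (n = k_h * k_w * out_channels, then normal_(0, sqrt(2. / n))) is He initialization computed over the fan-out. A minimal sketch of the equivalent written with torch.nn.init helpers (init_weights is a hypothetical standalone function, not code from this repo; the commented-out division by m.groups is ignored):

import torch.nn as nn

def init_weights(model):
    # Equivalent initialization using torch.nn.init helpers.
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            # He normal with fan-out: std = sqrt(2 / (k_h * k_w * out_channels))
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            nn.init.zeros_(m.bias)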

supermix.py

+4 -3

@@ -18,6 +18,7 @@
 from helper.util import get_teacher_name
 from models import model_dict
 import math
+import imageio
 
 
 def load_teacher(model_path, n_cls):

@@ -341,7 +342,7 @@ def augment(plot=True):
 
         img = img.astype(np.uint8)
 
-        misc.imsave(save_dir + '/' + str(counter + i) + '.png', img)
+        imageio.imwrite(save_dir + '/' + str(counter + i) + '.png', img)
 
         counter += n_suc
 

@@ -385,10 +386,10 @@ def count_parameters(model):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--path_t', type=str, default='./save/models/wrn_40_2_vanilla/ckpt_epoch_240.pth',
+    parser.add_argument('--path_t', type=str, default='./save/models/resnet110_vanilla/ckpt_epoch_240.pth',
                         help='teacher model snapshot')
     parser.add_argument('--device', type=str, default='cuda:0', help='cuda or cpu')
-    parser.add_argument('--save_dir', type=str, default='/home/aldb/outputs/new2',
+    parser.add_argument('--save_dir', type=str, default='/home/mehdi/output',
                         help='output directory to save results')
     parser.add_argument('--bs', type=int, default=100, help='batch size for dataloader')
     parser.add_argument('--aug_size', type=int, default=500000, help='number of samples to generate')
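
Note: scipy.misc.imsave was deprecated and then removed in recent SciPy releases, which is presumably why the call is swapped for imageio.imwrite here. A minimal usage sketch (the output path and the placeholder array are illustrative, not values from this repo):

import os
import numpy as np
import imageio

save_dir = '/tmp/supermix_out'                 # illustrative output directory
os.makedirs(save_dir, exist_ok=True)

img = np.zeros((32, 32, 3), dtype=np.uint8)    # CIFAR-sized placeholder image
# imageio.imwrite takes (filename, array), the same call shape as scipy.misc.imsave
imageio.imwrite(os.path.join(save_dir, '0.png'), img)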

train_series.py

+11 -9

@@ -81,7 +81,7 @@ def parse_option():
     parser.add_argument('--hint_layer', default=2, type=int, choices=[0, 1, 2, 3, 4])
 
     parser.add_argument('--test_interval', type=int, default=None, help='test interval')
-    parser.add_argument('--seed', default=102, type=int, help='random seed')
+    parser.add_argument('--seed', default=19, type=int, help='random seed')
 
     opt = parser.parse_args()
 

@@ -96,21 +96,21 @@ def parse_option():
 
 # gamma = [0.1, 0.3, 0.5, 0.7, 0.9]
 
-student_list = [8, 9, 10, 11, 12]
+student_list =range(10, 13)
 
 k_list = [3]
 k_list.reverse()
-for k in k_list:
+for s in student_list:
     opt = parse_option()
     # opt.aug_size = a
     opt.aug_alpha = 3
     opt.aug_lambda = -1
-    opt.gamma = 2
-    opt.alpha = 0
-    opt.aug_type = 'supermix'
-    opt.trial = "07Feb20"
-    s = 0
-    opt.aug_k = k
+    opt.gamma = 1
+    opt.alpha = 0.5
+    opt.aug_type = 'mixup'
+    opt.trial = "12Feb20_originit"
+    # s = 0
+    opt.aug_k = 2
 
 
     if s==0:

@@ -142,9 +142,11 @@ def parse_option():
     elif s==8:
         opt.model_s = 'MobileNetV2'
         opt.path_t = './save/models/ResNet50_vanilla/ckpt_epoch_240.pth'
+        opt.batch_size=64
     elif s==9:
         opt.model_s = 'vgg8'
         opt.path_t = './save/models/ResNet50_vanilla/ckpt_epoch_240.pth'
+        opt.batch_size = 64
     elif s==10:
         opt.model_s = 'ShuffleV1'
         opt.path_t = './save/models/resnet32x4_vanilla/ckpt_epoch_240.pth'
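
Note: the sweep now iterates directly over student indices rather than over k_list, with aug_k fixed at 2; range(10, 13) covers s = 10, 11, 12, so the two new batch_size=64 branches (s == 8 and s == 9) are not visited in this run. A one-line check:

# range(10, 13) yields the student indices 10, 11 and 12
print(list(range(10, 13)))    # -> [10, 11, 12]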

train_teacher.py

+9 -5

@@ -30,18 +30,20 @@ def parse_option():
     parser.add_argument('--batch_size', type=int, default=128, help='batch_size')
     parser.add_argument('--num_workers', type=int, default=8, help='num of workers to use')
     parser.add_argument('--epochs', type=int, default=600, help='number of training epochs')
-    parser.add_argument('--device', type=str, default='cuda:1', help='batch_size')
+    parser.add_argument('--device', type=str, default='cuda:0', help='batch_size')
 
     # optimization
-    parser.add_argument('--learning_rate', type=float, default=0.1, help='learning rate')
+    parser.add_argument('--learning_rate', type=float, default=0.02, help='learning rate')
     parser.add_argument('--lr_decay_epochs', type=str, default='200, 300, 400, 500', help='where to decay lr, can be a list')
     parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate')
     parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay')
     parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
     parser.add_argument('--aug', type=str, default=None,
                         help='address of the augmented dataset')
+    parser.add_argument('--aug_type', type=str, default=None,
+                        help='address of the augmented dataset')
     # dataset
-    parser.add_argument('--model', type=str, default='vgg8',
+    parser.add_argument('--model', type=str, default='MobileNetV2',
                         choices=['resnet8', 'resnet14', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110',
                                  'resnet8x4', 'resnet32x4', 'wrn_16_1', 'wrn_16_2', 'wrn_40_1', 'wrn_40_2',
                                  'vgg8', 'vgg11', 'vgg13', 'vgg16', 'vgg19',

@@ -53,8 +55,8 @@ def parse_option():
     opt = parser.parse_args()
 
     # set different learning rate from these 4 models
-    if opt.model in ['MobileNetV2', 'ShuffleV1', 'ShuffleV2']:
-        opt.learning_rate = 0.01
+    # if opt.model in ['MobileNetV2', 'ShuffleV1', 'ShuffleV2']:
+    #     opt.learning_rate = 0.01
 
     # set the path according to the environment
 

@@ -96,6 +98,8 @@ def main():
                           momentum=opt.momentum,
                           weight_decay=opt.weight_decay)
 
+    print("learning rate:", opt.learning_rate)
+
    criterion = nn.CrossEntropyLoss()
 
    if torch.cuda.is_available():
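
Note: the --lr_decay_epochs default stays a comma-separated string ('200, 300, 400, 500'). A minimal sketch of how such a string is typically turned into integer milestones (the parsing line is an assumption for illustration; the repo's own handling may differ):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--lr_decay_epochs', type=str, default='200, 300, 400, 500',
                    help='where to decay lr, can be a list')
opt = parser.parse_args([])    # use the defaults for this sketch

# Assumed parsing step: split on commas and convert each milestone to int.
lr_decay_epochs = [int(e) for e in opt.lr_decay_epochs.split(',')]
print(lr_decay_epochs)         # -> [200, 300, 400, 500]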
