# train_baselines.py
# Forked from gsarridis/InDistill.
import argparse
import torch.nn as nn
import torch.optim as optim
from utils.retrieval_evaluation import evaluate_model_retrieval
from utils.loaders import cifar10_loader
import torch
from utils.utils import *
from models.cnn32 import Cnn32
from models.cnn32_small import Cnn32_Small
from models.resnet import ResNet18
def train(net, optimizer, loss_fn, loader, epochs, device):
    """Train `net` for `epochs` epochs over `loader`, printing loss/accuracy per epoch.

    Args:
        net: model to train (modified in place and also returned).
        optimizer: optimizer over `net.parameters()`.
        loss_fn: criterion mapping (outputs, targets) -> scalar loss.
        loader: iterable of (inputs, targets) batches.
        epochs: number of passes over `loader`.
        device: torch device the batches are moved to.

    Returns:
        The trained `net`.
    """
    for epoch in range(epochs):
        net.train()
        train_loss = 0.0
        correct = 0
        total = 0
        for inputs, targets in tqdm(loader):
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = loss_fn(outputs, targets)
            loss.backward()
            optimizer.step()
            # .item() is the supported accessor; the old `.data` attribute
            # bypasses autograd tracking and is deprecated.
            train_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            # .sum().item() works on any device; no explicit .cpu() needed.
            correct += predicted.eq(targets).sum().item()
        # Guard against an empty loader so we report 0 instead of crashing.
        acc = correct / total if total else 0.0
        print(f"epoch: {epoch}, loss = {train_loss}, accuracy = {acc}")
    return net
def test(net, loss_fn, loader, device):
    """Evaluate `net` on `loader` and print total loss and accuracy.

    Args:
        net: model to evaluate (switched to eval mode).
        loss_fn: criterion mapping (outputs, targets) -> scalar loss.
        loader: iterable of (inputs, targets) batches.
        device: torch device the batches are moved to.
    """
    net.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    # Inference only: disable autograd to save memory and time.
    with torch.no_grad():
        for inputs, targets in tqdm(loader):
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = net(inputs)
            loss = loss_fn(outputs, targets)
            # .item() is the supported accessor; `.data` is deprecated.
            test_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    # Guard against an empty loader so we report 0 instead of crashing.
    acc = correct / total if total else 0.0
    print(f"test loss = {test_loss}, accuracy = {acc}")
def train_model(net, train_loader, test_loader, learning_rates, epochs, save_path, device):
    """Train `net` in successive stages, one per (learning rate, epoch count) pair.

    After each stage the model is evaluated on `test_loader` and its
    state dict is saved to `save_path` (overwriting the previous stage).

    Args:
        net: model to train.
        train_loader: training batches.
        test_loader: evaluation batches.
        learning_rates: per-stage learning rates, zipped with `epochs`.
        epochs: per-stage epoch counts.
        save_path: checkpoint path written after every stage.
        device: torch device for training/evaluation.

    Returns:
        The trained `net`.
    """
    criterion = nn.CrossEntropyLoss()
    for stage_lr, stage_epochs in zip(learning_rates, epochs):
        print(f'current lr: {stage_lr}')
        stage_optimizer = optim.Adam(net.parameters(), lr=stage_lr)
        net = train(net, stage_optimizer, criterion, train_loader,
                    epochs=stage_epochs, device=device)
        test(net, criterion, test_loader, device=device)
        torch.save(net.state_dict(), save_path)
    return net
def metric_learning_eval(net, dataset, filename):
    """Run retrieval evaluation for `net` and write score pickles under ./results/scores/.

    Args:
        net: trained model to evaluate.
        dataset: dataset identifier; only 'cifar10' is supported.
        filename: prefix for the result files (e.g. 'teacher').

    Raises:
        ValueError: if `dataset` is not a supported dataset name.
    """
    if dataset == 'cifar10':
        loader = cifar10_loader
    else:
        # Previously an unknown dataset fell through to a NameError on
        # `loader`; fail early with a clear message instead.
        raise ValueError(f"unsupported dataset: {dataset}")
    evaluate_model_retrieval(net=net, path='', dataset_loader=loader,
                             result_path='./results/scores/'+filename+'_baseline_'+dataset+'_retrieval.pickle', layer=3)
    evaluate_model_retrieval(net=net, path='', dataset_loader=loader,
                             result_path='./results/scores/'+filename+'_baseline_'+dataset+'_retrieval_e.pickle', layer=3, metric='l2')
def arg_parser(argv=None):
    """Parse command-line arguments for the baseline training script.

    Args:
        argv: optional list of argument strings; defaults to sys.argv[1:]
            (backward compatible with the original no-argument call).

    Returns:
        argparse.Namespace with batch_size, lr, ep, dataset and device.
    """
    parser = argparse.ArgumentParser()
    # `required=False` is the argparse default for optional flags, so it is omitted.
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--lr', nargs='+', type=float, default=[0.001, 0.0001], help='Learning rate value')
    parser.add_argument('--ep', nargs='+', type=int, default=[50, 30], help='Number of epochs')
    parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10'])
    parser.add_argument('--device', type=str, default='cuda:0')
    # Passing argv through makes the parser testable and reusable from code.
    args = parser.parse_args(argv)
    return args
def main():
    """Train the teacher, auxiliary and student baselines and run retrieval evaluation."""
    _ = ensure_reproducability()
    # Parse CLI arguments and resolve the compute device.
    args = arg_parser()
    device = torch.device(args.device)
    # Load dataset and instantiate the three baseline models.
    if args.dataset == 'cifar10':
        train_loader, test_loader, _ = cifar10_loader(batch=args.batch_size)
        nets = {
            'teacher': ResNet18(num_classes=10),
            'auxiliary': Cnn32(num_classes=10, input_channels=3),
            'student': Cnn32_Small(num_classes=10, input_channels=3),
        }
    # Move every model to the target device.
    nets = {role: model.to(device) for role, model in nets.items()}
    # Train each baseline and checkpoint it under ./results/models/.
    for role, model in nets.items():
        print(f"Training {role} model...")
        nets[role] = train_model(model, train_loader=train_loader, test_loader=test_loader,
                                 learning_rates=args.lr, epochs=args.ep,
                                 save_path='./results/models/' + role + '_baseline_' + args.dataset + '.pt',
                                 device=device)
    # Metric-learning (retrieval) evaluation for each trained model.
    for role, model in nets.items():
        metric_learning_eval(model, args.dataset, role)
if __name__ == '__main__':
    main()