-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_resnet.py
110 lines (91 loc) · 4.24 KB
/
train_resnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import torch
import torch.nn as nn
import argparse
import timm
import numpy as np
import utils
import random
def _parse_args():
    """Parse the command-line options for one training run."""
    parser = argparse.ArgumentParser()
    # --data selects both conf/<data>.json and the dataset branch, so it is
    # mandatory (omitting it previously crashed with a TypeError).
    parser.add_argument('--data', '-d', type=str, required=True)
    parser.add_argument('--gpu', '-g', default='0', type=str)
    # Sub-directory, joined onto config['save_path'], for checkpoints/results.
    parser.add_argument('--save_path', '-s', type=str, required=True)
    parser.add_argument('--noise_rate', '-n', type=float, default=0.2)
    # When set, only model.fc is optimised on top of the pretrained backbone.
    parser.add_argument('--linear', action='store_true', default=False)
    return parser.parse_args()


def _build_noisy_loaders(dataset, data_path, noise_rate, batch_size):
    """Return (train_loader, valid_loader) from the matching utils.get_*_noise_dataset helper.

    Raises:
        ValueError: if *dataset* matches none of the supported names.
    """
    if dataset == 'ham10000':
        return utils.get_noise_dataset(data_path, noise_rate=noise_rate, batch_size=batch_size)
    if dataset == 'aptos':
        return utils.get_aptos_noise_dataset(data_path, noise_rate=noise_rate, batch_size=batch_size)
    if dataset == 'idrid':
        return utils.get_idrid_noise_dataset(data_path, noise_rate=noise_rate, batch_size=batch_size)
    if 'mnist' in dataset:
        # The MNIST-family helper is given the dataset name, not data_path.
        return utils.get_mnist_noise_dataset(dataset, noise_rate=noise_rate, batch_size=batch_size)
    raise ValueError('Unsupported dataset: {!r}'.format(dataset))


def _train_one_epoch(model, train_loader, criterion, optimizer, device, linear):
    """Run one optimisation pass over *train_loader*; return (accuracy, mean batch loss)."""
    # In linear-probe mode the backbone is frozen, so keep the whole model in
    # eval mode: otherwise BatchNorm layers would still update their running
    # statistics from training batches even though no backbone weight is
    # optimised. (model.fc is an nn.Linear, which behaves the same in either mode.)
    model.train(not linear)
    total_loss = 0.0
    total = 0
    correct = 0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        if linear:
            # Backbone features need no gradient; only the head is trained.
            with torch.no_grad():
                features = model.forward_features(inputs)
                features = model.global_pool(features)
            outputs = model.fc(features)
        else:
            outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        # .item() yields a plain float; accumulating the tensor itself would
        # keep every iteration's autograd history alive for the whole epoch.
        total_loss += loss.item()
        total += targets.size(0)
        _, predicted = outputs[:len(targets)].max(1)
        correct += predicted.eq(targets).sum().item()
        print('\r', batch_idx, len(train_loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
              % (total_loss / (batch_idx + 1), 100. * correct / total, correct, total), end='')
    print()
    return correct / total, total_loss / len(train_loader)


def train():
    """Train a timm ResNet-50 (pretrained=True) on one of the noisy datasets.

    Reads ``conf/<data>.json`` via ``utils.read_conf`` (keys used: save_path,
    id_dataset, batch_size, epoch, num_classes). Optimises either the whole
    network or, with ``--linear``, only ``model.fc``, using Adam plus a
    MultiStepLR schedule. Validates after every epoch, saves checkpoints
    through timm's CheckpointSaver and finally writes the mean validation
    accuracy of the last (up to) 10 epochs to ``<save_path>/avgacc.txt``.
    """
    args = _parse_args()
    config = utils.read_conf('conf/' + args.data + '.json')
    device = 'cuda:' + args.gpu
    save_path = os.path.join(config['save_path'], args.save_path)
    data_path = config['id_dataset']
    batch_size = int(config['batch_size'])
    max_epoch = int(config['epoch'])
    noise_rate = args.noise_rate
    # makedirs also creates missing parents and is a no-op if the dir exists.
    os.makedirs(save_path, exist_ok=True)

    train_loader, valid_loader = _build_noisy_loaders(args.data, data_path, noise_rate, batch_size)

    model = timm.create_model('resnet50', pretrained=True, num_classes=config['num_classes'])
    model.to(device)
    criterion = torch.nn.CrossEntropyLoss()
    params = model.fc.parameters() if args.linear else model.parameters()
    optimizer = torch.optim.Adam(params, lr=1e-3, weight_decay=1e-5)
    # LR milestones at 50%, 75% and 90% of the run (MultiStepLR default gamma).
    lr_decay = [int(0.5 * max_epoch), int(0.75 * max_epoch), int(0.9 * max_epoch)]
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, lr_decay)
    # max_history=1: keep a single ranked checkpoint (ranked by `metric` below).
    saver = timm.utils.CheckpointSaver(model, optimizer, checkpoint_dir=save_path, max_history=1)
    print(train_loader.dataset[0][0].shape, args.linear)

    # Average over the final (up to) 10 epochs; clamp so short runs are not
    # divided by 10 when fewer epochs were accumulated.
    num_avg = min(10, max_epoch)
    avg_accuracy = 0.0
    for epoch in range(max_epoch):
        train_accuracy, train_avg_loss = _train_one_epoch(
            model, train_loader, criterion, optimizer, device, args.linear)

        model.eval()
        valid_accuracy = utils.validation_accuracy_resnet(model, valid_loader, device)
        scheduler.step()
        if epoch >= max_epoch - num_avg:
            avg_accuracy += valid_accuracy
        saver.save_checkpoint(epoch, metric=valid_accuracy)
        print('EPOCH {:4}, TRAIN [loss - {:.4f}, acc - {:.4f}], VALID [acc - {:.4f}]\n'.format(epoch, train_avg_loss, train_accuracy, valid_accuracy))
        print(scheduler.get_last_lr())
    with open(os.path.join(save_path, 'avgacc.txt'), 'w') as f:
        f.write(str(avg_accuracy / num_avg if num_avg else 0.0))
if __name__ == '__main__':
    # Launch a training run only when executed as a script, never on import.
    train()