# Copyright (C) 2019 National Center of Tumor Diseases (NCT) Dresden, Division of Translational Surgical Oncology
from train_opts import parser
from models import TSN
from dataset import TSNDataSet
from transforms import GroupNormalize
from eval import splits_LOSO, splits_LOUO, splits_LOUO_NP
from util import AverageMeter
import util
import os.path
import datetime
import string
import torch
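
# This script trains a single TSN model (presumably "Temporal Segment Networks") for
# video classification into args.num_class classes on one cross-validation split of the
# chosen evaluation scheme (LOSO or LOUO, see the split lists imported from eval above).
# Model, dataset, and command-line option definitions are imported from the repository's
# models, dataset, and train_opts modules. `args` is parsed in the __main__ block below
# and used as a module-level global inside main().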


def main():
    global args

    if not torch.cuda.is_available():
        print("GPU not found - exit")
        return
    device_gpu = torch.device("cuda:0")
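
    # If data_path or video_lists_dir contains a "{}" placeholder, fill it with the task
    # name so that per-task directory layouts can be expressed in a single option value.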
    if len([t for t in string.Formatter().parse(args.data_path)]) > 1:
        args.data_path = args.data_path.format(args.task)
    if len([t for t in string.Formatter().parse(args.video_lists_dir)]) > 1:
        args.video_lists_dir = args.video_lists_dir.format(args.task)

    output_folder = os.path.join(args.out, args.exp + "_" + datetime.datetime.now().strftime("%Y%m%d"),
                                 args.eval_scheme, str(args.split), datetime.datetime.now().strftime("%H%M"))
    os.makedirs(output_folder)
    f_log = open(os.path.join(output_folder, "log.txt"), "w")

    def log(msg):
        util.log(f_log, msg)

    log("Used parameters...")
    for arg in sorted(vars(args)):
        log("\t" + str(arg) + " : " + str(getattr(args, arg)))

    # ===== set up model =====
    consensus_type = 'avg'
    model = TSN(args.num_class, args.num_segments, args.modality, new_length=args.snippet_length,
                consensus_type=consensus_type, before_softmax=True, dropout=args.dropout, partial_bn=False,
                use_three_input_channels=args.three_channel_flow, pretrained_model=args.pretrain_path)

    # freeze weights
    # if args.arch == 'Pretrained-Inception-v3':
    for param in model.base_model.parameters():
        param.requires_grad = False
    for param in model.base_model.fc_action.parameters():
        param.requires_grad = True
    for name, module in model.base_model.named_modules():
        if name.startswith("mixed_10"):
            for param in module.parameters():
                param.requires_grad = True
    # elif args.arch == '3D-Resnet-34':
    #     for param in model.base_model.parameters():
    #         param.requires_grad = False
    #     for i in range(0, 3):
    #         block = getattr(model.base_model.layer4, str(i))
    #         for param in block.parameters():
    #             param.requires_grad = True
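
    # With the settings above, only the classification layer (fc_action) and the last
    # Inception block (modules whose names start with "mixed_10") are fine-tuned; all
    # other backbone weights stay frozen and keep their pretrained values.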

    # ===== set up data loader =====
    splits = None
    if args.eval_scheme == 'LOSO':
        splits = splits_LOSO
    elif args.eval_scheme == 'LOUO':
        if args.task == "Needle_Passing":
            splits = splits_LOUO_NP
        else:
            splits = splits_LOUO
    assert 0 <= args.split < len(splits)
    train_lists = splits[0:args.split] + splits[args.split + 1:]

    normalize = GroupNormalize(model.input_mean, model.input_std)
    train_augmentation = model.get_augmentation(args.do_horizontal_flip)
    lists_dir = os.path.join(args.video_lists_dir, args.eval_scheme)
    train_lists = list(map(lambda x: os.path.join(lists_dir, x), train_lists))
    log("Splits in train set :" + str(train_lists))
    train_set = TSNDataSet(args.data_path, train_lists, num_segments=args.num_segments, new_length=args.snippet_length,
                           modality=args.modality, image_tmpl=args.image_tmpl, transform=train_augmentation,
                           normalize=normalize, random_shift=True, test_mode=False,
                           video_sampling_step=args.video_sampling_step, video_suffix=args.video_suffix,
                           return_three_channels=args.three_channel_flow,
                           preload_to_RAM=args.data_preloading)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True,
                                               num_workers=args.workers, pin_memory=True)
    log("Loaded {} training videos".format(len(train_loader.dataset)))

    # ===== set up training =====
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
    log("param count: {}".format(sum(p.numel() for p in model.parameters())))
    log("trainable params: {}".format(sum(p.numel() for p in model.parameters() if p.requires_grad)))

    # ===== start! =====
    log("Start training...")
    model = model.to(device_gpu)
    torch.backends.cudnn.benchmark = True

    for epoch in range(0, args.epochs):
        train_loss = AverageMeter()
        train_acc = AverageMeter()
        model.train()

        for i, batch in enumerate(train_loader):
            optimizer.zero_grad()
            data, target = batch
            batch_size = target.size(0)
            data = data.to(device_gpu)
            target = target.to(device_gpu)

            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            train_loss.update(loss.item(), batch_size)
            _output = torch.nn.Softmax(dim=1)(output)
            _, predicted = torch.max(_output.data, 1)
            acc = (predicted == target).sum().item() / batch_size
            train_acc.update(acc, batch_size)
            print("Epoch {}: processed batch {}".format(epoch, i))  # simple progress indicator

        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:  # eval
            log("Epoch {}: Train loss: {train_loss.avg:.4f} Train acc: {train_acc.avg:.3f}".format(
                epoch, train_loss=train_loss, train_acc=train_acc))

        if (epoch + 1) % args.save_freq == 0 or epoch == args.epochs - 1:  # save
            name = "model_" + str(epoch)
            model_file = os.path.join(output_folder, name + ".pth.tar")
            state = {'epoch': epoch + 1,
                     'arch': args.arch,
                     'state_dict': model.state_dict(),
                     }
            torch.save(state, model_file)
            log("Saved model to " + model_file)

    log("Done.")
    f_log.close()
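
# Each saved checkpoint holds the 1-based epoch count, the architecture name, and the model
# weights; for later evaluation it can be restored along the lines of
#   model.load_state_dict(torch.load(model_file)['state_dict'])
# (a sketch; the repository's own evaluation code may load checkpoints differently).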


if __name__ == '__main__':
    # print(torch.rand(1, device="cuda:1"))
    args = parser.parse_args()
    args.num_class = 3
    args.video_suffix = "_capture2"
    args.image_tmpl = 'img_{:05d}.jpg'
    if args.modality == 'Flow':
        args.image_tmpl = 'flow_{}_{:05d}.jpg'

    if args.data_path == '?':
        print("Please specify the path to your (flow) image data using the --data_path option or set an appropriate "
              "default in train_opts.py!")
    elif args.out == '?':
        print("Please specify the path to your output folder using the --out option or set an appropriate default "
              "in train_opts.py!")
    else:
        main()
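
# Example invocation (a sketch, not from the repository's docs): --data_path and --out are
# the options named in the messages above; the remaining option names and values are
# assumptions about train_opts.py and may differ:
#   python train.py --data_path <path-to-extracted-frames> --out <output-folder> \
#       --task Suturing --eval_scheme LOUO --split 0 --modality RGB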