-
Notifications
You must be signed in to change notification settings - Fork 2
/
train.py
134 lines (106 loc) · 4.89 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from utils import *
import time
from tqdm.notebook import tqdm
import wandb
import torch.nn as nn
def fit(device, epoch, stop_epoch, min_val_loss, model, model_path, train_loader, val_loader, optimizer, scheduler, patch=False):
    """Train ``model`` with cross-entropy loss, validating after every epoch.

    Each epoch runs one pass over ``train_loader`` (stepping ``optimizer`` and
    ``scheduler`` once per batch), then one gradient-free pass over
    ``val_loader``. A checkpoint is written to ``{model_path}/checkpoint.pth``
    after every epoch; when the validation loss beats ``min_val_loss`` the
    checkpoint is saved again with the "best" flag and ``{model_path}/best.pth``
    (presumably ``save_checkpoint`` copies it there -- confirm in ``utils``).

    Early stopping: an epoch counts as "not decreasing" when its validation
    loss is neither a new best nor lower than or equal to the previous epoch's
    validation loss. The counter is cumulative (it is not reset on
    improvement); training stops once it reaches ``stop_epoch``.

    Args:
        device: Device the criterion and every batch are moved to.
        epoch: Maximum number of epochs to run.
        stop_epoch: Number of non-decreasing epochs that triggers early stop.
        min_val_loss: Best validation loss so far (e.g. from a resumed run).
        model: Network to train. It is not moved to ``device`` here.
        model_path: Directory prefix for ``checkpoint.pth`` / ``best.pth``.
        train_loader: Iterable of ``(image, mask)`` training batches.
        val_loader: Iterable of ``(image, mask)`` validation batches.
        optimizer: Optimizer stepped once per training batch.
        scheduler: LR scheduler stepped once per training batch.
        patch: If true, batches are 5-D ``(bs, n_tiles, c, h, w)`` tile stacks
            that are flattened to ``(bs * n_tiles, c, h, w)`` before the
            forward pass.

    Returns:
        dict with keys ``train_loss``, ``val_loss``, ``train_score``,
        ``val_score``, ``train_acc`` and ``val_acc``; each maps to a list with
        one entry per completed epoch (including the epoch that triggered
        early stopping).
    """
    criterion = nn.CrossEntropyLoss().to(device)
    wandb.watch(model, criterion, log='all', log_freq=10)

    train_losses, val_losses = [], []
    val_iou, train_iou, val_acc, train_acc = [], [], [], []
    loss_notdecrease_count = 0
    prev_e_loss = min_val_loss
    # Loop-invariant batch counts used to turn per-batch sums into means.
    n_train_batches = len(train_loader)
    n_val_batches = len(val_loader)

    fit_time = time.time()
    for e in range(epoch):
        since = time.time()
        t_loss_perb, t_iou_perb, t_acc_perb = 0, 0, 0
        v_loss_perb, v_iou_perb, v_acc_perb = 0, 0, 0

        # ---- training pass ----
        model.train()
        for data in tqdm(train_loader):
            image, mask = _prepare_batch(data, device, patch)

            output = model(image)
            loss = criterion(output, mask)

            t_iou_perb += mIoU(output, mask)
            t_acc_perb += pixel_accuracy(output, mask)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # The scheduler is advanced per batch, not per epoch.
            scheduler.step()

            t_loss_perb += loss.item()
            torch.cuda.empty_cache()

        # ---- validation pass (no gradients) ----
        model.eval()
        with torch.no_grad():
            for data in tqdm(val_loader):
                image, mask = _prepare_batch(data, device, patch)
                output = model(image)

                v_iou_perb += mIoU(output, mask)
                v_acc_perb += pixel_accuracy(output, mask)

                loss = criterion(output, mask)
                v_loss_perb += loss.item()
                torch.cuda.empty_cache()

        # Per-epoch means over batches.
        t_loss_pere = t_loss_perb / n_train_batches
        v_loss_pere = v_loss_perb / n_val_batches
        epoch_train_iou = t_iou_perb / n_train_batches
        epoch_val_iou = v_iou_perb / n_val_batches
        epoch_train_acc = t_acc_perb / n_train_batches
        epoch_val_acc = v_acc_perb / n_val_batches

        train_losses.append(t_loss_pere)
        val_losses.append(v_loss_pere)

        checkpoint = {
            'epoch': e + 1,
            'train_loss': t_loss_pere,
            'val_loss': v_loss_pere,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict()
        }
        save_checkpoint(checkpoint, f'{model_path}/checkpoint.pth')
        print('save model...')

        stop_training = False
        if min_val_loss > v_loss_pere:
            print('Loss Decreasing.. {:.3f} >> {:.3f} '.format(min_val_loss, v_loss_pere))
            min_val_loss = v_loss_pere
            save_checkpoint(checkpoint, f'{model_path}/checkpoint.pth', True, f'{model_path}/best.pth')
        elif v_loss_pere > prev_e_loss:
            loss_notdecrease_count += 1
            print(f'Loss Not Decrease for {loss_notdecrease_count} time')
            if loss_notdecrease_count == stop_epoch:
                print(f'Loss not decrease for {stop_epoch} times, Stop Training')
                # Defer the break so this epoch's metrics are still recorded
                # and every history list keeps the same length.
                stop_training = True
        # Compare the next epoch against this one; without this update the
        # reference stays at the initial min_val_loss and the early-stop
        # branch can never fire when that initial value is very large.
        prev_e_loss = v_loss_pere

        val_iou.append(epoch_val_iou)
        train_iou.append(epoch_train_iou)
        train_acc.append(epoch_train_acc)
        val_acc.append(epoch_val_acc)
        print("Epoch:{}/{}..".format(e+1, epoch),
              "Train Loss: {:.3f}..".format(t_loss_pere),
              "Val Loss: {:.3f}..".format(v_loss_pere),
              "Train Score:{:.3f}..".format(epoch_train_iou),
              "Val Score: {:.3f}..".format(epoch_val_iou),
              "Train Acc:{:.3f}..".format(epoch_train_acc),
              "Val Acc:{:.3f}..".format(epoch_val_acc),
              "Time: {:.2f}m".format((time.time()-since)/60))

        if stop_training:
            break

    history = {'train_loss': train_losses, 'val_loss': val_losses,
               'train_score': train_iou, 'val_score': val_iou,
               'train_acc': train_acc, 'val_acc': val_acc}
    print('Total time: {:.2f} m' .format((time.time()- fit_time)/60))
    return history


def _prepare_batch(data, device, patch):
    """Unpack an ``(image, mask)`` batch, flatten tiles if ``patch``, move to ``device``."""
    image_tiles, mask_tiles = data
    if patch:
        # Merge the batch and tile axes: (bs, n_tiles, c, h, w) -> (bs*n_tiles, c, h, w).
        _, _, c, h, w = image_tiles.size()
        image_tiles = image_tiles.view(-1, c, h, w)
        mask_tiles = mask_tiles.view(-1, h, w)
    return image_tiles.to(device), mask_tiles.to(device)