-
Notifications
You must be signed in to change notification settings - Fork 0
/
trainloop.py
163 lines (127 loc) · 5.56 KB
/
trainloop.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import copy
import os
import shutil
import pickle
import time
from optimizers import *
from main_utils import *
from MAS_utils import *
def _mas_train_epoch(model, optimizer, model_criterion, dataloader, device):
    """Run one optimisation pass of ``model.xmodel`` over ``dataloader``.

    Returns:
        (summed_loss, correct): the sum of ``loss.item()`` over all batches and
        the number of correct top-1 predictions, both as Python numbers.
    """
    model.xmodel.train(True)
    # Moving the module is loop-invariant, so do it once instead of per batch.
    model.xmodel.to(device)
    summed_loss = 0.0
    correct = 0
    for input_data, labels in dataloader:
        # On CPU .to(device) is a no-op; this replaces the deprecated Variable wrapper.
        input_data = input_data.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        output = model.xmodel(input_data)
        _, predictions = torch.max(output, 1)
        loss = model_criterion(output, labels)
        loss.backward()
        # The project optimiser's step() is handed the model's params object.
        optimizer.step(model.params)
        summed_loss += loss.item()
        correct += torch.sum(predictions == labels).item()
    return summed_loss, correct


def _mas_count_correct(model, dataloader, device):
    """Return how many top-1 predictions of ``model.xmodel`` on ``dataloader`` are correct.

    Leaves ``model.xmodel`` in eval mode; runs without gradient tracking.
    """
    model.xmodel.eval()
    correct = 0
    # Pure inference: no backward pass follows, so don't build autograd graphs.
    with torch.no_grad():
        for input_data, labels in dataloader:
            input_data = input_data.to(device)
            labels = labels.to(device)
            _, preds = torch.max(model.xmodel(input_data), 1)
            correct += torch.sum(preds == labels).item()
    return correct


def mas_train(model,optimizer, model_criterion,task,epochs,no_of_classes,lr,scheduler_lambda,num_frozen,use_gpu,trdataload,tedataload,train_size,test_size):
    """Train ``model`` on one task, then run the omega (importance) update.

    Epochs ``start_epoch .. epochs-1`` train ``model.xmodel`` on ``trdataload``.
    One extra final iteration (``epoch == epochs``) does no training. Instead it:

    * calls ``omega_update`` / ``compute_omega_grads_norm`` on the training data;
    * measures accuracy on ``tedataload``.

    Finally ``save_model(model, task, accuracy)`` is called.

    Args:
        model: wrapper object exposing ``.xmodel`` (the module that is trained
            and evaluated), ``.params`` and ``.state_dict()``.
        optimizer: optimiser with ``zero_grad``, ``step(model.params)`` and
            ``state_dict``. It is re-obtained each epoch via
            ``scheduler(optimizer, epoch, lr)``.
        model_criterion: loss function called as ``model_criterion(output, labels)``.
        task: task id; checkpoints go to ``models/Task_<task>`` under the cwd.
        epochs: number of training epochs before the omega pass.
        no_of_classes: forwarded to ``create_task_dir`` / ``model_initialiser``.
        lr: base learning rate forwarded to ``scheduler``.
        scheduler_lambda: forwarded to ``local_sgd`` when resuming a checkpoint.
        num_frozen: currently unused.
        use_gpu: if true, run on ``cuda:0``; otherwise on CPU.
        trdataload: iterable of ``(inputs, labels)`` training batches.
        tedataload: iterable of ``(inputs, labels)`` test batches.
        train_size: denominator for the reported training loss/accuracy.
        test_size: denominator for the final test accuracy.
    """
    store_path = os.path.join(os.getcwd(), "models", "Task_" + str(task))
    model_path = os.path.join(os.getcwd(), "models")
    device = torch.device("cuda:0" if use_gpu else "cpu")
    # Create the parent models directory whenever it is missing (not only for task 1).
    os.makedirs(model_path, exist_ok=True)

    # Checkpoint resumption is disabled; re-enable it with
    # ``checkpoint_file, flag = check_checkpoints(store_path)``.
    checkpoint_file, flag = "", False
    if not flag:
        create_task_dir(no_of_classes, store_path)
        start_epoch = 0
    elif not checkpoint_file:
        start_epoch = 0
    else:
        print("Loading checkpoint '{}' ".format(checkpoint_file))
        checkpoint = torch.load(checkpoint_file, map_location=device)
        # 'epoch' stores the last *completed* epoch, so resume after it.
        start_epoch = checkpoint['epoch'] + 1
        print("Loading the model")
        # Rebuilds the model with task-specific classifier layers on top of the shared weights.
        model = model_initialiser(no_of_classes, use_gpu)
        # Accept the keys written below, falling back to the legacy names.
        if 'model_state_dict' in checkpoint:
            model_state = checkpoint['model_state_dict']
        else:
            model_state = checkpoint['state_dict']
        # load_state_dict mutates in place; its return value is not the model.
        model.load_state_dict(model_state)
        print('Loading the optimizer')
        optimizer = local_sgd(model.params, scheduler_lambda)
        if 'optimizer_state_dict' in checkpoint:
            optimizer_state = checkpoint['optimizer_state_dict']
        else:
            optimizer_state = checkpoint['optimizer']
        optimizer.load_state_dict(optimizer_state)
        print('Done')

    model.xmodel.train(True)
    model.xmodel.to(device)
    for epoch in range(start_epoch, epochs + 1):
        if epoch == epochs:
            # Omega accumulation happens once training has converged; no weights are trained here.
            optimizer_ft = omega_update(model.params)
            print("Updating the omega values for this task")
            model = compute_omega_grads_norm(model, trdataload, optimizer_ft, use_gpu)
            epoch_accuracy = _mas_count_correct(model, tedataload, device) / test_size
        else:
            print("Training on epoch no {} of {}".format(epoch + 1, epochs))
            print("-" * 20)
            # The project scheduler returns the (possibly adjusted) optimiser for this epoch.
            optimizer = scheduler(optimizer, epoch, lr)
            running_loss, running_corrects = _mas_train_epoch(
                model, optimizer, model_criterion, trdataload, device)
            # NOTE(review): running_loss sums per-batch losses; if the criterion
            # uses mean reduction, dividing by train_size under-reports the loss
            # by roughly the batch size — confirm the criterion's reduction.
            epoch_loss = running_loss / train_size
            epoch_accuracy = running_corrects / train_size
            print("Loss: {} Accuracy:{} ".format(epoch_loss, epoch_accuracy))
            # Periodic checkpoint every 10 epochs, skipping the last training
            # epoch so the final state is not written twice.
            if epoch != 0 and epoch != epochs - 1 and (epoch + 1) % 10 == 0:
                epoch_file = os.path.join(store_path, str(epoch + 1) + ".pth.tar")
                torch.save({
                    'epoch': epoch,
                    'epoch_loss': epoch_loss,
                    'epoch_accuracy': epoch_accuracy,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }, epoch_file)
    save_model(model, task, epoch_accuracy)