-
Notifications
You must be signed in to change notification settings - Fork 0
/
trainer.py
142 lines (116 loc) · 3.73 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import torch
from torch.autograd import Variable
class Trainer(object):
    """
    Run training and validation passes of a model over a dataloader.

    The trainer holds the learning-rate scheduler, the optimizer and the
    loss criterion. For every pass it accumulates the loss (weighted by the
    batch size) and the number of correct predictions over all batches.

    :param scheduler: object exposing ``step()``; stepped once per ``train``
    :param optimizer: object exposing ``zero_grad()`` and ``step()``
    :param criterion: callable ``criterion(outputs, labels)`` returning the loss
    """

    def __init__(self, scheduler, optimizer, criterion):
        self.scheduler = scheduler
        self.optimizer = optimizer
        self.criterion = criterion
        self._gpu = False
        self._debug = False

    @property
    def gpu(self):
        """
        Use the GPU instead of the CPU.
        """
        return self._gpu

    @gpu.setter
    def gpu(self, value):
        self._gpu = value

    @property
    def debug(self):
        """
        Prints stats if set to True.
        """
        return self._debug

    @debug.setter
    def debug(self, value):
        self._debug = value

    def train(self, model, dataloader, dataset_size):
        """
        Train the model for one pass over the specified dataloader.

        :param model: model
        :param dataloader: DataLoader
        :param dataset_size: The number of images in the dataset
        """
        model.train(True)
        # persist=True is what actually enables back propagation and the
        # optimizer update; without it the pass is evaluation only.
        loss, accuracy = self.learn(model, dataloader, persist=True)
        # Step the scheduler after the optimizer updates of this pass:
        # stepping it first skips the first value of the schedule on
        # PyTorch >= 1.1.
        self.scheduler.step()

        if self._debug:
            print('Train Loss: {:.4f} Accuracy: {:.4f}'.format(
                loss / dataset_size,
                accuracy / dataset_size
            ))

    def validate(self, model, dataloader, dataset_size):
        """
        Validate the model using the specified dataloader.

        The model weights are not updated.

        :param model: model
        :param dataloader: DataLoader
        :param dataset_size: The number of images in the dataset
        """
        model.train(False)
        loss, accuracy = self.learn(model, dataloader)

        if self._debug:
            print('Validation Loss: {:.4f} Accuracy: {:.4f}'.format(
                loss / dataset_size,
                accuracy / dataset_size
            ))

    def learn(self, model, dataloader, persist=False):
        """
        Feed-forward all the data in the dataloader.

        :param model: model
        :param dataloader: iterable yielding ``(inputs, labels)`` batches
        :param persist: boolean True if model should use back propagation
        :returns: total loss (float), number of correct predictions (int)
        """
        total_loss = 0.0
        total_correct = 0

        for data in dataloader:
            # Separate per-batch names so the running totals are not
            # overwritten by the result of the current batch.
            batch_loss, batch_correct = self.evaluate_data(
                model, data, persist
            )
            total_loss += batch_loss
            # calc_accuracy may return a 0-dim tensor; keep the total a
            # plain Python int.
            total_correct += int(batch_correct)

        return total_loss, total_correct

    def evaluate_data(self, model, data, persist=False):
        """
        Feed-forward one batch into the model.

        :param model: model
        :param data: ``(inputs, labels)`` pair as yielded by the dataloader
        :param persist: boolean True if model should use back propagation
        :returns: loss, accuracy
        """
        inputs, labels = self.create_vars(data)

        self.optimizer.zero_grad()
        # Only record the autograd graph when it is going to be used.
        with torch.set_grad_enabled(bool(persist)):
            outputs = model(inputs)
            loss = self.criterion(outputs, labels)
        _, predictions = torch.max(outputs.data, 1)

        if persist:
            loss.backward()
            self.optimizer.step()

        return (
            self.calculate_loss(loss, inputs),
            self.calc_accuracy(labels, predictions),
        )

    def create_vars(self, data):
        """
        Generate the input and label variables.

        The tensors are moved to the GPU first when ``gpu`` is set.

        :param data: ``(inputs, labels)`` pair
        :returns: Variable, Variable
        """
        inputs, labels = data

        if self._gpu:
            inputs = Variable(inputs.cuda())
            labels = Variable(labels.cuda())
        else:
            inputs, labels = Variable(inputs), Variable(labels)

        return inputs, labels

    def calculate_loss(self, loss, inputs):
        """
        Calculate the total loss of a batch.

        :param loss: single-element loss tensor returned by the criterion
        :param inputs: batch inputs; ``inputs.size(0)`` is used as the weight
        :returns: float, the loss value multiplied by ``inputs.size(0)``
        """
        # .item() replaces loss.data[0], which fails on 0-dim tensors.
        return loss.item() * inputs.size(0)

    def calc_accuracy(self, labels, predictions):
        """
        Count the predictions that match the labels.

        :param labels: expected labels of the batch
        :param predictions: predicted labels of the batch
        :returns: number of matching entries, as returned by ``torch.sum``
        """
        return torch.sum(predictions == labels.data)