-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_valid_test.py
73 lines (53 loc) · 1.98 KB
/
train_valid_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# -*- coding: utf-8 -*-
import torch
from utils import timeit
from torch.autograd import Variable
#@timeit
def train_epoch(model, tr_loader, criterion, optimizer, lr, results):
    """Train ``model`` for one pass over ``tr_loader`` and record the epoch metrics.

    For every ``(X, labels)`` batch this clears gradients with
    ``model.zero_grad()``, runs the forward pass, computes
    ``criterion(y_pred, labels)``, back-propagates, calls ``optimizer.step()``
    and finally ``model.collect_stats(lr)``.

    Args:
        model: callable module being trained; must also provide
            ``zero_grad()`` and ``collect_stats(lr)``.
        tr_loader: iterable of ``(X, labels)`` batches.
        criterion: loss function called as ``criterion(y_pred, labels)``.
        optimizer: optimizer whose ``step()`` updates the model parameters.
        lr: value forwarded to ``model.collect_stats``. It is NOT applied to
            ``optimizer`` here.
        results: object exposing ``train_loss`` and ``train_accy`` lists;
            one value is appended to each.

    Returns:
        ``(loss, accuracy)``. ``loss`` is the mean of the per-batch
        ``loss.item()`` values, rounded to 3 decimals. ``accuracy`` is the
        percentage of correct predictions, rounded to 2 decimals.

    Raises:
        ValueError: if ``tr_loader`` yields no batches.
    """
    train_loss = 0.0
    correct, total, n_batches = 0, 0, 0
    # Run minibatches from the training dataset
    for X, labels in tr_loader:
        # Forward pass
        model.zero_grad()
        y_pred = model(X)
        # Predicted class = index of the max score along dim 1
        # (assumes y_pred is (batch, num_classes) -- TODO confirm).
        _, preds = torch.max(y_pred.detach(), 1)
        # Compute loss
        loss = criterion(y_pred, labels)
        # Backward pass
        loss.backward()
        optimizer.step()
        # Collect stats
        train_loss += loss.item()
        model.collect_stats(lr)
        n_batches += 1
        total += y_pred.size(0)
        # Tensor-level reduction; builtin sum() would iterate element by element.
        correct += int((preds == labels).sum().item())
    if n_batches == 0:
        raise ValueError("train_epoch: tr_loader yielded no batches")
    # Compute and store epoch results: mean over the number of batches.
    lss = round(train_loss / n_batches, 3)
    acc = round(correct / total * 100, 2)
    results.train_accy.append(acc)
    results.train_loss.append(lss)
    return lss, acc
def valid_epoch(model, ts_loader, criterion, results):
    """Evaluate ``model`` on ``ts_loader`` without gradients and record the metrics.

    The model is put in evaluation mode for the pass, so that mode-dependent
    layers (e.g. dropout, batch norm) use their inference behaviour and
    batch-norm running statistics are not updated. Its previous
    train/eval mode is restored afterwards, even if an exception is raised.

    Args:
        model: ``torch.nn.Module`` to evaluate (its ``.training`` flag and
            ``.eval()``/``.train()`` are used).
        ts_loader: iterable of ``(X, labels)`` batches.
        criterion: loss function called as ``criterion(y_pred, labels)``.
        results: object exposing ``valid_loss`` and ``valid_accy`` lists;
            one value is appended to each.

    Returns:
        ``(loss, accuracy)``. ``loss`` is the mean of the per-batch
        ``loss.item()`` values, rounded to 3 decimals. ``accuracy`` is the
        percentage of correct predictions, rounded to 3 decimals.

    Raises:
        ValueError: if ``ts_loader`` yields no batches.
    """
    was_training = model.training
    model.eval()
    valid_loss = 0.0
    correct, total, n_batches = 0, 0, 0
    try:
        with torch.no_grad():
            for X, labels in ts_loader:
                # Forward pass
                y_pred = model(X)
                # Predicted class = index of the max score along dim 1
                # (assumes y_pred is (batch, num_classes) -- TODO confirm).
                _, preds = torch.max(y_pred, 1)
                # Compute loss
                loss = criterion(y_pred, labels)
                valid_loss += loss.item()
                n_batches += 1
                total += y_pred.size(0)
                # Tensor-level reduction; builtin sum() would iterate element by element.
                correct += int((preds == labels).sum().item())
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)
    if n_batches == 0:
        raise ValueError("valid_epoch: ts_loader yielded no batches")
    # Compute and store epoch results: mean over the number of batches.
    lss = round(valid_loss / n_batches, 3)
    acc = round(correct / total * 100, 3)
    results.valid_loss.append(lss)
    results.valid_accy.append(acc)
    return lss, acc