train.py
# Libraries
import math
import utils
import model
import torch
import numpy as np
from torch import nn
from tqdm import tqdm
import matplotlib.pyplot as plt
# Model parameters
input_size = 3
hidden_size = 128
n_layers = 1
dropout = 0
classifier_dim = [128, 64]
n_classes = 3
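# (All of the above are passed straight through to model.Predictor below.)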
# Training parameters
n_epochs = 20
batch_size = 256
init_lr = 4e-3
best_score = math.inf
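# best_score tracks the lowest validation loss seen so far; starting at
# infinity guarantees the first epoch always writes a checkpoint.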
# Load datasets
x_train_path, y_train_path = './Data/x_train.npy', './Data/y_train.npy'
x_val_path, y_val_path = './Data/x_test.npy', './Data/y_test.npy'
trainSet = utils.CreateDataset(x_train_path, y_train_path)
valSet = utils.CreateDataset(x_val_path, y_val_path)
n_train, n_val = len(trainSet), len(valSet)
# Data loader
train_DL = torch.utils.data.DataLoader(trainSet, batch_size = batch_size, shuffle = True)
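# The validation loader serves the entire split as a single batch of size n_val.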
val_DL = torch.utils.data.DataLoader(valSet, batch_size = n_val, shuffle = False)
# Instantiate the network
network = model.Predictor(input_size, hidden_size, n_layers, dropout, classifier_dim, n_classes)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
network = network.to(device)
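# Run on the first CUDA device when available, otherwise fall back to the CPU.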
# Loss, optimizer and scheduler
criterion = nn.CrossEntropyLoss(reduction = "mean")
optimizer = torch.optim.AdamW(network.parameters(), lr = init_lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = n_epochs, eta_min = 1e-8)
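# The cosine schedule decays the learning rate from init_lr toward eta_min = 1e-8 over the n_epochs epochs.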
list_train_loss = []
list_val_loss = []
# Training
print("Beginning of the training...")
print("")
# Loop on epochs
for epoch in range(1, n_epochs + 1):
    # Running sums used to average the training / validation loss over the epoch
    # (renamed from n / n_val so they no longer shadow the dataset sizes above)
    n_seen = 0
    n_val_seen = 0
    running_loss = 0.0
    running_loss_val = 0.0
    tqdm_train = tqdm(train_DL, total = len(train_DL))
    tqdm_train.set_description(f"Epoch {epoch}")
    # Loop on training batches
    for x, y in tqdm_train:
        # Training mode
        network.train()
        # Clean gradients
        optimizer.zero_grad()
        # Forward pass and loss (gradients are enabled by default here, so the
        # original torch.set_grad_enabled(True) wrapper was redundant)
        x, y = x.to(device), y.to(device)
        logits = network(x)
        loss = criterion(logits, y)
        loss.backward()
        # Update parameters
        optimizer.step()
        # Update the running loss, weighted by batch size so the mean is exact
        running_loss += loss.item() * x.size(0)
        n_seen += x.size(0)
        tqdm_train.set_postfix(loss = running_loss / n_seen)
    list_train_loss.append(running_loss / n_seen)
    # Loop on validation batches (the accuracy is now accumulated across
    # batches rather than computed mid-loop, which only worked because the
    # validation loader held a single batch)
    network.eval()
    n_correct = 0
    with torch.no_grad():
        for x, y in val_DL:
            x, y = x.to(device), y.to(device)
            logits = network(x)
            loss = criterion(logits, y)
            running_loss_val += loss.item() * x.size(0)
            n_val_seen += x.size(0)
            n_correct += torch.sum(torch.argmax(logits, dim = -1) == y).item()
    # Accuracy over the full validation split, in percent
    acc = 100 * n_correct / n_val_seen
    print("Validation loss: {:.3f}".format(running_loss_val / n_val_seen))
    print("Validation accuracy: {:.1f}%".format(acc))
    print()
    list_val_loss.append(running_loss_val / n_val_seen)
    # Update learning rate
    scheduler.step()
    # Save the best model (lowest validation loss so far)
    if list_val_loss[-1] < best_score:
        best_score = list_val_loss[-1]
        torch.save(network, './trainedModel.pt')
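        # Note: torch.save on the whole module pickles the model class itself;
        # saving network.state_dict() instead is the more portable convention.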
# Save the figure of the loss curves
assert len(list_val_loss) == len(list_train_loss)
epochs = np.arange(1, len(list_val_loss) + 1)
plt.figure(figsize = (10, 6))
plt.plot(epochs, np.asarray(list_train_loss), color = 'blue', label = 'Training')
plt.plot(epochs, np.asarray(list_val_loss), color = 'orange', label = 'Validation')
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()  # without this call the 'Training' / 'Validation' labels never appear
plt.savefig('./lossEvolution.png')  # output path is an assumption; the original only displayed the figure
plt.show()