import os
import torch
import numpy as np
from torch_geometric.loader import DataLoader
from torch.utils.data import WeightedRandomSampler
from misc import print_dataset_stats
from precision_eval import CriterionLoss, CorrectClass
from precision_cmp import train_first_elem_cmp, val_first_elem_cmp
import time
# Train the model for one pass (epoch) over the given loader.
#
def train(model, criterion, optimizer, loader, get_label_f=None, batch_transformer=lambda d : d):
    model.train()
    for data in loader:
        data = batch_transformer(data)  # transform the batch if needed
        out = model(data)
        labels = get_label_f(data)      # get the label (class for classification, value for regression)
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
# Evaluate a model on the given loader using the provided precision evaluators.
#
def test_c(model, criterion, loader, out_suffix='', epoch=None, out_path='/tmp', get_label_f=None, batch_transformer=lambda d : d, precision_evals=[]):
    model.eval()
    with torch.no_grad():
        # reset all precision evaluators
        for peval in precision_evals:
            peval.reset(epoch=epoch, out_suffix=out_suffix, out_path=out_path)
        for data in loader:
            data = batch_transformer(data)  # transform the batch if needed
            out = model(data)               # apply the model
            labels = get_label_f(data)      # get the labels from the data
            # evaluate all precision criteria
            for peval in precision_evals:
                peval.eval(out, labels, data, criterion)
    return [ (peval.tag(), peval.loss(), peval.report()) for peval in precision_evals ]
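
# A minimal sketch of the interface test_c expects from a precision evaluator:
# reset() is called once before a pass over a loader, eval() once per batch, and
# tag()/loss()/report() summarize the pass. The real evaluators (e.g. CriterionLoss)
# live in the precision_eval module; the class below is only an illustration and is
# not used by this file.
class ExampleMeanLoss:
    def reset(self, epoch=None, out_suffix='', out_path='/tmp'):
        # clear the running sums at the start of a pass
        self.total = 0.0
        self.batches = 0
    def eval(self, out, labels, data, criterion):
        # accumulate the criterion value of one batch
        self.total += criterion(out, labels).item()
        self.batches += 1
    def tag(self):
        return 'mean_loss'
    def loss(self):
        return self.total / self.batches if self.batches else float('inf')
    def report(self):
        return f'{self.batches} batches'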
# Creates a loader and, if requested, balances the set (currently only for classification).
#
# TBD: we always use the DataLoader of PyG; it seems to work with plain torch datasets as well, but this should be verified.
#
def create_loader(dataset, name="dataset", get_class_f_for_balancing=None, balance=False, regression=False, batch_size=64):
    def get_class(d):
        c = get_class_f_for_balancing(d)
        return c.item() if type(c) == torch.Tensor else c
    if not regression and balance:
        print(f"Balancing the {name} ...", flush=True)
        all_labels = [ get_class(d) for d in dataset ]
        labels_unique, counts = np.unique(all_labels, return_counts=True)
        class_weights = { labels_unique[i]: sum(counts)/counts[i] if counts[i] else 0 for i in range(len(counts)) }
        data_weights = [ class_weights[get_class(d)] for d in dataset ]  # weight for every example in the dataset
        sampler = WeightedRandomSampler(data_weights, len(data_weights))
        loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    else:
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return loader
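
# Worked example of the balancing above: with class labels [0, 0, 0, 1] the counts are
# {0: 3, 1: 1}, so class_weights = {0: 4/3, 1: 4} and data_weights = [4/3, 4/3, 4/3, 4].
# The single example of class 1 is then drawn three times as often as any individual
# example of class 0, so each class contributes roughly half of every epoch on average.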
# The main training loop
#
def training(model = None,                    # a model suitable for the provided dataset; its forward receives one argument (the batch) and does all the work
             criterion = None,                # loss function
             optimizer = None,                # an optimizer
             dataset = None,                  # the dataset (suitable for the model)
             epochs=10,                       # number of epochs
             train_set_size_percentage=0.8,   # how to split the set into training and validation
             regression=False,                # whether it is a regression or a classification problem (will eventually be eliminated)
             testset=None,                    # an optional test set, assumed to be independent of dataset
             get_label_f=None,                # a function to extract the labels from a batch (to support several kinds of datasets)
             balance_train_set=True,          # whether the training set should be balanced when creating its loader (classification only for now)
             balance_validation_set=True,     # same as above, but for the validation set
             balance_testset=True,            # same as above, but for the test set
             get_class_f_for_balancing = lambda d : get_label_f(d).item(),  # a function that returns the label used for balancing
             batch_size = 64,                 # the size of each batch created by the loaders
             batch_transformer=lambda d : d,  # in case batches need to be transformed before being sent to the model
             precision_evals=[CriterionLoss()],  # a list of precision evaluators (see the module precision_eval)
             prec_cmp=train_first_elem_cmp,   # a comparator of precision between epochs (see the module precision_cmp)
             save_models = None,              # can be 'all', 'last', or None
             save_improved_only = False,      # if True, save a model only when the precision improved
             sim_train = False,               # re-print statistics for an already trained model, reusing the saved training/validation split
             out_path = '/tmp'):              # where to save models, split indices and evaluator output
    # if there is a dataset, then we are doing training
    if dataset is not None:
        print()
        print(f'** The training set ({train_set_size_percentage*100:.2f}% for training and {(1-train_set_size_percentage)*100:.2f}% for validation)')
        print_dataset_stats(dataset, regression=regression)
        if not sim_train:
            # split the dataset into training and validation, and save the chosen indices
            #
            dataset = dataset.shuffle()
            train_set_size = int(len(dataset)*train_set_size_percentage)
            train_set = dataset[:train_set_size]
            validation_set = dataset[train_set_size:]
            outfile = open(f'{out_path}/train_set_idx.txt', "w")
            for t in train_set:
                outfile.write(f"{t[3]['idx']}")
                outfile.write('\n')
            outfile.close()
            outfile = open(f'{out_path}/val_set_idx.txt', "w")
            for t in validation_set:
                outfile.write(f"{t[3]['idx']}")
                outfile.write('\n')
            outfile.close()
        else:
            # reuse the training/validation split saved in a previous run
            infile = open(f'{out_path}/train_set_idx.txt', "r")
            data = infile.read()
            train_set = [ dataset[int(idx)] for idx in data.split() ]
            infile.close()
            infile = open(f'{out_path}/val_set_idx.txt', "r")
            data = infile.read()
            validation_set = [ dataset[int(idx)] for idx in data.split() ]
            infile.close()
        # Create the loaders. In the case of classification we might balance them
        train_loader = create_loader(train_set, name="training set", regression=regression, balance=balance_train_set, get_class_f_for_balancing=get_class_f_for_balancing, batch_size=batch_size)
        val_loader = create_loader(validation_set, name="validation set", regression=regression, balance=balance_validation_set, get_class_f_for_balancing=get_class_f_for_balancing, batch_size=batch_size)
        # best epoch information
        best_epoch = None
        best_epoch_train_loss = [ float('inf') ] * len(precision_evals)
        best_epoch_val_loss = [ float('inf') ] * len(precision_evals)
    else:
        epochs = 1  # otherwise we are just testing, so we fix the number of epochs to 1
    # create the loader for the test set, if provided
    if testset is not None:
        print()
        print('** The test set')
        print_dataset_stats(testset, regression=regression)
        test_loader = create_loader(testset, name="test set", regression=regression, balance=balance_testset, get_class_f_for_balancing=get_class_f_for_balancing, batch_size=batch_size)
    print()
    t1 = t2 = t3 = t4 = t5 = t6 = t7 = t8 = 0
    last_filename = None
    for epoch in range(1, epochs+1):
        print(f'Epoch {epoch:03d}', end="", flush=True)
        # calculate the precision on the train/validation sets
        #
        if dataset is not None:
            # train
            #
            t1 = time.time()
            if not sim_train:
                train(model, criterion, optimizer, train_loader, get_label_f=get_label_f, batch_transformer=batch_transformer)
            t2 = time.time()
            train_precision = test_c(model, criterion, train_loader, get_label_f=get_label_f, batch_transformer=batch_transformer, precision_evals=precision_evals, epoch=epoch, out_suffix='train', out_path=out_path)
            t3 = time.time()
            val_precision = test_c(model, criterion, val_loader, get_label_f=get_label_f, batch_transformer=batch_transformer, precision_evals=precision_evals, epoch=epoch, out_suffix='val', out_path=out_path)
            t4 = time.time()
            # check if there is an improvement wrt. the best epoch
            curr_train_loss = [ loss for (_, loss, _) in train_precision ]
            curr_val_loss = [ loss for (_, loss, _) in val_precision ]
            improved = False
            if prec_cmp(curr_train_loss, curr_val_loss, best_epoch_train_loss, best_epoch_val_loss):
                best_epoch_train_loss = curr_train_loss
                best_epoch_val_loss = curr_val_loss
                best_epoch = epoch
                improved = True
            t5 = time.time()
            # save the model if requested
            if save_models is not None:
                filename = None
                if (save_improved_only and improved) or (not save_improved_only):
                    filename = f'{out_path}/model_{"i_" if improved else ""}{epoch}.pyt'
                if filename is not None:
                    if save_models == 'last' and last_filename is not None:
                        os.remove(last_filename)    # remove the last one saved
                    torch.save(model, filename)     # save the new one
                    last_filename = filename
            t6 = time.time()
            # report the train/validation precision of this epoch
            mark = '*' if improved else ''
            print(f'{mark} \t ')
            print('\tTrain: ', end="")
            for (tag, loss, prec_info) in train_precision:
                print(f'{tag}={loss:.4f} ({prec_info})', end="")
                print("\t", end="")
            print()
            print('\tVal: ', end="")
            for (tag, loss, prec_info) in val_precision:
                print(f'{tag}={loss:.4f} ({prec_info})', end="")
                print("\t", end="")
            print(flush=True)
        t8 = t7 = time.time()
        # calculate and report the precision on the test set, if provided
        if testset is not None:
            test_precision = test_c(model, criterion, test_loader, get_label_f=get_label_f, batch_transformer=batch_transformer, precision_evals=precision_evals, epoch=epoch, out_suffix='test', out_path=out_path)
            t8 = time.time()
            print('\tTest: ', end="")
            for (tag, loss, prec_info) in test_precision:
                print(f'{tag}={loss:.4f} ({prec_info})', end="")
                print("\t", end="")
            print(flush=True)
        print(f'\tTimes: {(t2-t1):.2f} {(t3-t2):.2f} {(t4-t3):.2f} {(t5-t4):.2f} {(t6-t5):.2f} {(t7-t6):.2f} {(t8-t7):.2f}')
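
# A minimal, self-contained smoke test of the helpers above (a sketch, not part of the
# original pipeline): it builds a tiny synthetic graph-classification dataset, trains a
# trivial model for a few epochs with create_loader/train, and evaluates it with test_c
# using the illustrative ExampleMeanLoss evaluator defined above. The full training()
# entry point is not exercised here, since it additionally assumes a project-specific
# dataset object (with a .shuffle() method and elements indexable as t[3]['idx']) and
# the helpers from the misc and precision_eval modules.
if __name__ == '__main__':
    from torch_geometric.data import Data
    from torch_geometric.nn import global_mean_pool

    class TinyGraphClassifier(torch.nn.Module):
        # mean-pool the node features of each graph and classify with a single linear layer
        def __init__(self, in_dim=8, num_classes=2):
            super().__init__()
            self.lin = torch.nn.Linear(in_dim, num_classes)
        def forward(self, data):
            return self.lin(global_mean_pool(data.x, data.batch))

    # imbalanced synthetic dataset: 30 graphs of class 0 and 10 graphs of class 1
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
    data_list = [ Data(x=torch.randn(4, 8), edge_index=edge_index, y=torch.tensor([0])) for _ in range(30) ] + \
                [ Data(x=torch.randn(4, 8) + 1.0, edge_index=edge_index, y=torch.tensor([1])) for _ in range(10) ]

    model = TinyGraphClassifier()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    get_label_f = lambda d: d.y

    loader = create_loader(data_list, name="toy set", balance=True,
                           get_class_f_for_balancing=get_label_f, batch_size=8)
    for epoch in range(1, 4):
        train(model, criterion, optimizer, loader, get_label_f=get_label_f)
        results = test_c(model, criterion, loader, get_label_f=get_label_f,
                         precision_evals=[ExampleMeanLoss()], epoch=epoch, out_suffix='demo')
        print(f'Epoch {epoch}:', results)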