train.py
import logging

import torch
# from torch.utils.tensorboard import SummaryWriter
import wandb

logger = logging.getLogger(__name__)


class Trainer(object):

    def __init__(self, net, trainloader, optimizer, numb_of_itrs, eval_every, save_path, evaluator):
        self.net = net
        self.trainloader = trainloader
        self.optimizer = optimizer
        self.numb_of_itrs = numb_of_itrs
        self.eval_every = eval_every
        self.save_path = save_path
        self.evaluator = evaluator

    def training_step(self, data, epoch):
        # One optimization step on a single minibatch.
        self.optimizer.zero_grad()
        loss, log = self.net.loss(data, epoch)
        loss.backward()
        self.optimizer.step()
        return log

    def train(self, start_iteration=1):
        print("Start training...")
        self.net.train()

        iteration = start_iteration
        print_every = 1

        for epoch in range(10000000):  # loop over the dataset multiple times
            for data in self.trainloader:
                # Training step; pass the current epoch (the original passed the
                # fixed start_iteration, which never changes across iterations).
                log = self.training_step(data, epoch)

                if iteration % print_every == 0:
                    log_vals = {key: value / print_every for key, value in log.items()}
                    log_vals['iteration'] = iteration
                    # Assumed logging sink: wandb is imported but log_vals was
                    # built and then never emitted in the original.
                    wandb.log(log_vals)

                iteration += 1

                if iteration % self.eval_every == self.eval_every - 1:  # evaluate every eval_every iterations
                    self.evaluator.evaluate(iteration)

                if iteration > self.numb_of_itrs:
                    break
            if iteration > self.numb_of_itrs:  # also stop the outer epoch loop
                break

        logger.info("... end training!")
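
# Usage sketch (hypothetical): `MyNet`, `my_dataset`, `MyEvaluator`, and the
# learning rate / iteration counts below are illustrative stand-ins for this
# project's actual model, dataset, and evaluator, which are not defined in
# this file. The only contract Trainer relies on is that
# `net.loss(data, epoch)` returns a scalar loss tensor plus a dict of values
# to log, and that `evaluator.evaluate(iteration)` handles checkpointing.
#
#   net = MyNet()
#   trainloader = torch.utils.data.DataLoader(my_dataset, batch_size=1, shuffle=True)
#   optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
#   evaluator = MyEvaluator(net, save_path='experiments/exp1/')
#   trainer = Trainer(net, trainloader, optimizer,
#                     numb_of_itrs=300000, eval_every=1000,
#                     save_path='experiments/exp1/', evaluator=evaluator)
#   trainer.train()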