-
Notifications
You must be signed in to change notification settings - Fork 109
/
train_logger.py
36 lines (30 loc) · 1.19 KB
/
train_logger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import time
class TrainLogger:
    """Periodically print training progress: epoch, batch, throughput and mean loss.

    Call the instance once per training step. The first call (and the first
    call after the batch index goes backwards, e.g. at the start of a new
    epoch) only starts the timer; its loss is not recorded. Every subsequent
    call accumulates ``loss``; whenever ``batch`` is a multiple of
    ``frequency`` a progress line is printed and the accumulators are reset.

    Args:
        batch_size: Per-device batch size. Multiplied by ``num_gpus`` to get
            the number of samples per step used for the throughput figure.
        frequency: Log every ``frequency`` batches; must be positive.
        num_gpus: Number of devices. Also scales the batch indices shown in
            the log line.

    Raises:
        ValueError: If ``frequency`` is not positive.
    """

    def __init__(self, batch_size, frequency=50, num_gpus=1):
        if frequency <= 0:
            raise ValueError(f"frequency must be positive, got {frequency}")
        self.num_gpus = num_gpus
        # Samples processed per step across all devices.
        self.batch_size = batch_size * num_gpus
        self.frequency = frequency
        self.init = False
        self.tic = 0
        self.last_batch = 0
        self.running_loss = 0
        # Number of losses summed into running_loss since the last log/reset.
        self.count = 0

    def _reset_window(self):
        """Start a fresh measurement window: clear accumulators and restart the timer."""
        self.running_loss = 0
        self.count = 0
        # perf_counter is monotonic, so the interval is immune to wall-clock jumps.
        self.tic = time.perf_counter()

    def __call__(self, epoch, total_epochs, batch, total, loss):
        """Record one step's ``loss``; print a progress line every ``frequency`` batches.

        Args:
            epoch: Epoch index; printed as ``epoch + 1``.
            total_epochs: Total number of epochs, printed as-is.
            batch: Per-device batch index within the epoch.
            total: Per-device number of batches in the epoch.
            loss: Loss for this step. It is summed, divided by the step count
                and formatted with ``:.5f``. NOTE(review): if callers pass
                framework tensors rather than floats, accumulating them may
                retain extra state — confirm callers pass plain numbers.
        """
        # A smaller batch index than last time means a new epoch (or restart):
        # drop the partial window so its losses don't leak into the next average.
        if self.last_batch > batch:
            self.init = False
        self.last_batch = batch

        if not self.init:
            # The first step of a window only starts the clock; its loss is not counted.
            self.init = True
            self._reset_window()
            return

        self.running_loss += loss
        self.count += 1
        if batch % self.frequency != 0:
            return

        elapsed = time.perf_counter() - self.tic
        # Use the real number of recorded steps rather than assuming `frequency`:
        # they differ when a window is short (1-based indices, restarts).
        speed = self.count * self.batch_size / elapsed if elapsed > 0 else 0.0
        self.running_loss = self.running_loss / self.count
        batch, total = batch * self.num_gpus, total * self.num_gpus
        log = (
            f"Epoch: [{epoch + 1}-{total_epochs}] Batch: [{batch}-{total}] "
            + f"Speed: {speed:.2f} samples/sec Loss: {self.running_loss:.5f}"
        )
        print(log)
        self._reset_window()