# train_stats.py: Prometheus gauges plus a running-average accumulator
# (Stats) for per-step training metrics.
from time import strftime
from prometheus_client import Gauge
# Prometheus gauges holding the metrics of the most recent training step.
# Stats.__call__ updates them; accuracy gauges are set in percent.
loss_gauge = Gauge(name="training_loss", documentation="Training loss")
moves_accuracy_gauge = Gauge(
    name="training_move_accuracy",
    documentation="Move accuracy",
)
moves_top5_accuracy_gauge = Gauge(
    name="training_move_top5_accuracy",
    documentation="Top 5 move accuracy",
)
score_mae_gauge = Gauge(
    name="training_score_mae",
    documentation="Score mean absolute error",
)
# NOTE(review): unlike the others this metric has no "training_" prefix;
# kept as-is so existing dashboards/queries keep working.
wdl_accuracy_gauge = Gauge(name="wdl_accuracy", documentation="WDL accuracy")
class Stats(object):
    """Accumulate per-step training metrics and report running averages.

    Each call records one step's metrics and updates the module-level
    Prometheus gauges with that step's values. It then returns a one-line
    progress string. Running averages are weighted by the ``cnt`` passed
    with each step.
    """

    def __init__(self, wdl_loss_weight=0.25):
        """Start with empty accumulators.

        Args:
            wdl_loss_weight: Weight applied to the WDL loss when backing the
                residual ("reg") loss out of the total loss for display.
                Defaults to 0.25, the value previously hard-coded.
        """
        self.wdl_loss_weight = wdl_loss_weight
        self.sum_moves_accuracy = 0
        self.sum_moves_top5_accuracy = 0
        self.sum_score_mae = 0
        self.sum_loss = 0
        self.sum_wdl_accuracy = 0
        self.sum_cnt = 0

    def _averages(self):
        """Return the running averages as a 5-tuple.

        The fields are, in order: loss, move accuracy (%), top-5 move
        accuracy (%), score MAE and WDL accuracy (%). Every field is NaN when
        no samples have been accumulated yet, instead of dividing by zero.
        """
        if self.sum_cnt == 0:
            nan = float("nan")
            return (nan, nan, nan, nan, nan)
        return (
            self.sum_loss / self.sum_cnt,
            self.sum_moves_accuracy * 100 / self.sum_cnt,
            self.sum_moves_top5_accuracy * 100 / self.sum_cnt,
            self.sum_score_mae / self.sum_cnt,
            self.sum_wdl_accuracy * 100 / self.sum_cnt,
        )

    def __call__(self, step_output, cnt):
        """Record one training step and return a formatted progress line.

        Args:
            step_output: Indexable whose items 0-7 are, in order: total loss,
                moves loss, score loss, WDL loss, move accuracy, top-5 move
                accuracy, score MAE and WDL accuracy. The accuracies are
                multiplied by 100 for display. This presumably means they
                arrive as fractions in [0, 1]; confirm with the caller.
            cnt: Sample count for this step; weights the running averages.

        Returns:
            A string with this step's metrics followed by the running
            averages. Averages print as "nan" if no samples have been
            accumulated yet.
        """
        loss = step_output[0]
        moves_loss = step_output[1]
        score_loss = step_output[2]
        wdl_loss = step_output[3]
        moves_accuracy = step_output[4]
        moves_top5_accuracy = step_output[5]
        score_mae = step_output[6]
        wdl_accuracy = step_output[7]
        # reg_loss is the part of the total loss that the listed components
        # do not explain. Presumably this is a regularization term; confirm.
        reg_loss = abs(
            loss - moves_loss - score_loss - self.wdl_loss_weight * wdl_loss
        )

        loss_gauge.set(loss)
        moves_accuracy_gauge.set(moves_accuracy * 100)
        moves_top5_accuracy_gauge.set(moves_top5_accuracy * 100)
        score_mae_gauge.set(score_mae)
        wdl_accuracy_gauge.set(wdl_accuracy * 100)

        self.sum_moves_accuracy += moves_accuracy * cnt
        self.sum_moves_top5_accuracy += moves_top5_accuracy * cnt
        self.sum_score_mae += score_mae * cnt
        self.sum_loss += loss * cnt
        self.sum_wdl_accuracy += wdl_accuracy * cnt
        self.sum_cnt += cnt

        return (
            "loss: {:.2f} = {:.2f} + {:.3f} + {:.3f}, "
            "moves: {:4.1f}% top 5: {:4.1f}%, "
            "score: {:.2f}, wdl: {:4.1f}% || "
            "avg: {:.3f}, {:.2f}% top 5: {:.2f}%, {:.3f}, wdl: {:.2f}%"
        ).format(
            loss,
            moves_loss,
            score_loss,
            reg_loss,
            moves_accuracy * 100,
            moves_top5_accuracy * 100,
            score_mae,
            wdl_accuracy * 100,
            *self._averages(),
        )

    def write_to_file(self, model_name, filename="stats.txt"):
        """Append one timestamped summary line of the running averages.

        Args:
            model_name: Label written in brackets on the line.
            filename: File to append to (created if missing).
        """
        with open(filename, "a", encoding="utf-8") as statsfile:
            print(
                "{} [{}] {} positions: {:.3f}, {:.2f}% top 5: {:.2f}%, {:.3f}, wdl: {:.2f}%".format(
                    strftime("%Y-%m-%d %H:%M"),
                    model_name,
                    self.sum_cnt,
                    *self._averages(),
                ),
                file=statsfile,
            )