-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpgn_reader.py
executable file
·161 lines (110 loc) · 3.76 KB
/
pgn_reader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
#!/usr/bin/env python3
import chess.pgn
from chess_input import Repr2D
import random
import re
from prometheus_client import Counter
from dataclasses import dataclass, field
from typing import Any
# Priority ceiling for the position queue: data items are given random
# priorities no larger than this, and the end-of-input sentinel is queued at it.
MAX_PRIO = 10 ** 6
@dataclass(order=True)
class PrioritizedItem:
    """One entry for the position priority queue.

    Ordering and equality look at ``priority`` only: every payload field is
    declared with ``compare=False`` and never takes part in comparisons.
    """

    # Sort key; smaller values compare as smaller.
    priority: int
    # Encoded board position (None for the end-of-input sentinel).
    data_board: Any = field(compare=False)
    # Policy target produced for the position.
    label_moves: Any = field(compare=False)
    # Value target.
    label_value: Any = field(compare=False)
    # Win/draw/loss target.
    label_wdl: Any = field(compare=False)
def label_for_result(result, turn):
    """Return a one-hot ``[win, draw, loss]`` label for the side given by ``turn``.

    ``"1-0"`` with a truthy ``turn`` yields ``[1, 0, 0]`` and with a falsy one
    ``[0, 0, 1]``; ``"0-1"`` is the mirror image.  Any other result string
    (draws, ``"*"``, anything unrecognised) yields the draw label
    ``[0, 1, 0]``.  A new list is returned on every call.
    """
    win, draw, loss = [1, 0, 0], [0, 1, 0], [0, 0, 1]
    if result == "1-0":
        return win if turn else loss
    if result == "0-1":
        return loss if turn else win
    return draw
# Shared board/policy encoder used when building training items.
# NOTE(review): this name shadows the builtin ``repr``; it is kept because
# other functions in this module refer to it by this name.
repr = Repr2D()
# Search summary stored in a PGN comment: "q=<float>; p=[<key>:<float>, ...]".
# Raw strings: "\[" in a normal string literal is an invalid escape sequence.
re1 = re.compile(r"q=(.*); p=\[(.*)\]")
# One "<key>:<value>" entry from inside the p=[...] list.
re2 = re.compile(r"(.*):(.*)")
def parse_mcts_result(input):
    """Split a search-summary comment into ``(q, policy)``.

    Returns ``(None, None)`` when ``input`` does not start with the
    ``q=...; p=[...]`` form matched by ``re1``.  Otherwise ``q`` is the float
    after ``q=`` and ``policy`` maps the key of each ``", "``-separated
    ``key:value`` entry to its float value; entries without a ``:`` are
    skipped.  ``float`` raises ValueError for a non-numeric field.
    """
    header = re1.match(input)
    if header is None:
        return None, None
    q_value = float(header.group(1))
    policy = {}
    for entry in header.group(2).split(", "):
        pair = re2.match(entry)
        if pair is None:
            continue
        policy[pair.group(1)] = float(pair.group(2))
    return q_value, policy
def randomize_item(item):
    """Give ``item`` a fresh random priority and return it.

    The priority is drawn from ``[0, MAX_PRIO)``, strictly below
    ``MAX_PRIO``.  ``MAX_PRIO`` is the priority of the end-of-input sentinel,
    and a data item that tied with it could be handed out after the sentinel
    by a heap-based priority queue.

    ``item`` is modified in place; it is returned to allow chaining.
    """
    item.priority = random.randrange(MAX_PRIO)
    return item
def traverse_game(node, board, queue, result, sample_rate, follow_variations=False):
    """Walk a PGN game tree and queue sampled training positions.

    For each visited node that has a comment, with probability
    ``sample_rate`` percent the comment is parsed with parse_mcts_result().
    If it parses, the position *before* ``node.move`` is encoded and put on
    ``queue`` as a PrioritizedItem with a random priority below ``MAX_PRIO``.
    The sentinel from end_of_input_item() therefore always sorts after data.

    Args:
        node: PGN game node to start from (typically the game root).
        board: Board positioned just before ``node.move``.  Moves are pushed
            and popped during the walk, so it is left unchanged on return.
        queue: Destination; only ``put`` is called on it.
        result: Result header string passed to label_for_result().
        sample_rate: Percentage (0-100) of commented positions to keep;
            100 keeps every one of them.
        follow_variations: When False, nodes off the main line are skipped;
            when True, every variation is walked as well.

    Returns:
        The number of positions queued.
    """
    positions_created = 0
    if not follow_variations and not node.is_mainline():
        return positions_created
    move = node.move
    # randrange(100) is 0..99, so sample_rate is an exact percentage.
    if node.comment and random.randrange(100) < sample_rate:
        q, policy = parse_mcts_result(node.comment)
        # Comments that don't match the expected format give (None, None);
        # skip them rather than failing on the arithmetic below.
        if q is not None:
            # Presumably maps q from [0, 1] onto [-1, 1] — confirm with producer.
            q = q * 2 - 1.0
            z = label_for_result(result, board.turn)
            train_data_board = repr.board_to_array(board)
            train_labels1 = repr.policy_to_array(board, policy)
            item = PrioritizedItem(
                random.randrange(MAX_PRIO), train_data_board, train_labels1, q, z
            )
            queue.put(item)
            positions_created += 1
    if move is not None:
        board.push(move)
    # NOTE(review): recursion depth grows by one per ply; extremely long games
    # could approach Python's default recursion limit.
    for child in node.variations:
        positions_created += traverse_game(
            child, board, queue, result, sample_rate, follow_variations
        )
    if move is not None:
        board.pop()
    return positions_created
# Prometheus counter of games read for training, labelled with the value of
# each game's PGN "Result" header.
game_counter = Counter(
    "training_game_total",
    "Games seen by training",
    labelnames=["result"],
)
def pos_generator(filename, test_mode, queue, max_positions=2_500_000):
    """Read games from a PGN file and queue sampled training positions.

    Games are read one at a time and passed to traverse_game() (main line
    only) until the file is exhausted or at least ``max_positions`` positions
    have been queued.  Each game read increments ``game_counter`` for its
    result.  Finally, the sentinel from end_of_input_item() is put on
    ``queue``.

    Args:
        filename: Path of the PGN file.
        test_mode: When true, traverse_game() is given a sample rate of 100
            instead of 50.
        queue: Destination passed through to traverse_game(); only ``put``
            is called on it here.
        max_positions: Stop reading new games once this many positions have
            been queued.  The check happens between games, so the final count
            can overshoot by up to one game's worth.
    """
    sample_rate = 100 if test_mode else 50
    cnt = 0
    positions_created = 0
    with open(filename) as pgn:
        while positions_created < max_positions:
            try:
                game = chess.pgn.read_game(pgn)
            except (UnicodeDecodeError, ValueError):
                # A tuple is required to catch both; `A or B` only catches A.
                continue
            if game is None:
                break  # end of file
            result = game.headers["Result"]
            date_of_game = game.headers["Date"]
            game_counter.labels(result=result).inc()
            cnt += 1
            if cnt % 10 == 0:
                print(
                    "Parsing game #{} {}, {} positions (avg {:.1f} pos/game)".format(
                        cnt,
                        date_of_game,
                        positions_created,
                        positions_created / cnt,
                    ),
                    end="\r",
                )
            positions_created += traverse_game(
                game, game.board(), queue, result, sample_rate
            )
    # Guard against an empty/unreadable file, where cnt stays 0.
    avg = positions_created / cnt if cnt else 0.0
    print(
        f"Parsed {cnt} games, {positions_created} positions (avg {avg:.1f} pos/game)."
    )
    queue.put(end_of_input_item())
def end_of_input_item():
    """Build the sentinel that marks the end of the position stream.

    It has priority ``MAX_PRIO`` and ``None`` in every payload field.
    """
    return PrioritizedItem(
        priority=MAX_PRIO,
        data_board=None,
        label_moves=None,
        label_value=None,
        label_wdl=None,
    )