-
Notifications
You must be signed in to change notification settings - Fork 0
/
move_selection.py
executable file
·76 lines (56 loc) · 1.9 KB
/
move_selection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#!/usr/bin/env python3
from math import log
from random import uniform
import numpy as np
def sample_gumbel(a):
b = [log(x) - log(-log(uniform(0, 1))) for x in a]
return np.argmax(b)
def select_root_move(tree, move_count, sample=True, *,
                     visit_power=2.0, sample_moves=15):
    """Pick the move to play from the root of a search tree.

    Only children visited at least once are candidates; each is weighted by
    ``visit_count ** visit_power``.  While ``sample`` is true and
    ``move_count < sample_moves`` the move is drawn at random in proportion
    to those weights (via ``sample_gumbel``); otherwise the highest-weight
    candidate is returned (for a positive ``visit_power`` that is the
    most-visited child, the first one on ties).

    Args:
        tree: Root node; ``tree.children`` maps moves to child nodes that
            expose a ``visit_count`` attribute.
        move_count: Number of moves already played, compared against
            ``sample_moves``.
        sample: Whether stochastic selection is allowed at all.
        visit_power: Exponent applied to visit counts (default 2.0); larger
            values concentrate the sampling distribution on the most-visited
            moves.
        sample_moves: Sampling is used only while ``move_count`` is below
            this value (default 15).

    Returns:
        The chosen move key, or ``None`` if the root has no visited children.
    """
    moves = []
    weights = []
    for move, child in tree.children.items():
        if child.visit_count > 0:
            moves.append(move)
            weights.append(child.visit_count ** visit_power)

    # Covers both "no children" and "no child visited yet"; np.argmax would
    # raise ValueError on the empty weight list.
    if not moves:
        return None

    if sample and move_count < sample_moves:
        idx = sample_gumbel(weights)
    else:
        idx = np.argmax(weights)
    return moves[idx]
def select_root_move_delta(tree, move_count, sample=True, delta=0.02, *,
                           visit_power=2.0, sample_moves=15):
    """Pick a root move, restricted to moves whose value is near the best.

    The reference value is ``value()`` of the most-visited child (the first
    such child on ties).  Candidates are children with at least one visit
    whose ``value()`` is strictly greater than ``reference - delta``.  Each
    candidate is weighted by ``visit_count ** visit_power``.  While
    ``sample`` is true and ``move_count < sample_moves`` the move is drawn
    in proportion to those weights (via ``sample_gumbel``); otherwise the
    highest-weight candidate is returned.

    Args:
        tree: Root node; ``tree.children`` maps moves to child nodes exposing
            ``visit_count`` and a ``value()`` method.
        move_count: Number of moves already played, compared against
            ``sample_moves``.
        sample: Whether stochastic selection is allowed at all.
        delta: Tolerance below the reference value a candidate may have.
        visit_power: Exponent applied to visit counts (default 2.0).
        sample_moves: Sampling is used only while ``move_count`` is below
            this value (default 15).

    Returns:
        The chosen move key, or ``None`` if no child qualifies (no children,
        none visited, or ``delta`` so small that every child is filtered out).
    """
    if not tree.children:
        return None

    # max() returns the first maximal element, matching the original tie
    # behaviour; value() is only needed for this one reference child.
    most_visited = max(tree.children.values(),
                       key=lambda child: child.visit_count)
    if most_visited.visit_count == 0:
        # No child has been visited, so there is nothing to choose from.
        return None
    threshold = most_visited.value() - delta

    moves = []
    weights = []
    for move, child in tree.children.items():
        if child.visit_count > 0 and child.value() > threshold:
            moves.append(move)
            weights.append(child.visit_count ** visit_power)

    # With delta <= 0 the filter can reject every child; np.argmax would
    # raise on the empty list.
    if not moves:
        return None

    if sample and move_count < sample_moves:
        idx = sample_gumbel(weights)
    else:
        idx = np.argmax(weights)
    return moves[idx]
def add_exploration_noise(node, *, dirichlet_alpha=0.3,
                          exploration_fraction=0.25):
    """Mix Dirichlet noise into the priors of ``node``'s children, in place.

    A Dirichlet(alpha, ..., alpha) sample is drawn as normalised
    Gamma(alpha, 1) variates (numpy's global RNG), and each child's prior
    becomes ``prior * (1 - exploration_fraction) + noise * exploration_fraction``.
    If the priors summed to 1 beforehand they still sum to 1 afterwards.

    Args:
        node: Node whose ``children`` dict maps actions to objects with a
            mutable ``prior`` attribute.
        dirichlet_alpha: Concentration parameter of the noise (default 0.3).
        exploration_fraction: Weight given to the noise (default 0.25).
    """
    actions = list(node.children)
    if not actions:
        # Nothing to perturb; also avoids normalising an empty array by 0.
        return

    noise = np.random.gamma(dirichlet_alpha, 1, len(actions))
    noise /= np.sum(noise)
    keep = 1 - exploration_fraction
    for action, eps in zip(actions, noise):
        child = node.children[action]
        child.prior = child.prior * keep + eps * exploration_fraction
def add_bias_move(node, move, bias=0.9):
    """Boost the prior of ``move`` and renormalise all child priors in place.

    ``bias`` is added to ``node.children[move].prior``, then every child's
    prior is divided by the new total so that the priors sum to 1 (assuming
    non-negative priors, the total is positive).

    Args:
        node: Node whose ``children`` dict maps moves to objects with a
            mutable ``prior`` attribute.
        move: Key of the child to favour; must be in ``node.children``.
        bias: Amount added to that child's prior before renormalising
            (default 0.9).

    Raises:
        KeyError: if ``move`` is not a child of ``node``.
    """
    node.children[move].prior += bias
    total = sum(child.prior for child in node.children.values())
    for child in node.children.values():
        child.prior /= total