-
Notifications
You must be signed in to change notification settings - Fork 0
/
node.py
61 lines (45 loc) · 1.43 KB
/
node.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import random
from math import log, sqrt
class Node:
def __init__(self, state):
self.state = state
self.win_value = 0
self.policy_value = None
self.visits = 0
self.parent = None
self.children = []
self.expanded = False
self.player_number = None
self.discovery_factor = 0.35
def update_win_value(self, value):
self.win_value += value
self.visits += 1
if self.parent:
self.parent.update_win_value(value)
def update_policy_value(self, value):
self.policy_value = value
def add_child(self, child):
self.children.append(child)
child.parent = self
def add_children(self, children):
for child in children:
self.add_child(child)
def get_preferred_child(self, root_node):
best_children = []
best_score = float('-inf')
for child in self.children:
score = child.get_score(root_node)
if score > best_score:
best_score = score
best_children = [child]
elif score == best_score:
best_children.append(child)
return random.choice(best_children)
def get_score(self, root_node):
discovery_operand = self.discovery_factor * (self.policy_value or 1) * sqrt(log(self.parent.visits) / (self.visits or 1))
win_multiplier = 1 if self.parent.player_number == root_node.player_number else -1
win_operand = win_multiplier * self.win_value / (self.visits or 1)
self.score = win_operand + discovery_operand
return self.score
def is_scorable(self):
return self.visits or self.policy_value != None