MCTS.py
from API import *
from programGraph import *
from pointerNetwork import *

import math  # used by the PUCT exploration bonus in uct()
import time


class MCTS(Solver):
"""
AlphaZero-style Monte Carlo tree search
Currently ignores learned distance / value, but is biased by learned policy
"""
def __init__(self, model, _=None, reward=None,
c_puct=5, rolloutDepth=None):
"""
c_puct: Trades off exploration and exploitation. Larger values favor exploration, guided by policy.
reward: function from loss to reward.
"""
assert reward is not None, "reward must be specified. This function converts loss into reward."
self.reward = reward
self.c_puct = c_puct
self.model = model
self.rolloutDepth = rolloutDepth
self.beamTime = 0.
self.rollingTime = 0.
def __str__(self):
return f"MCTS(puct={self.c_puct})"
    def _infer(self, spec, loss, timeout):
        startTime = time.time()
        owner = self

        class Node:
            def __init__(self, graph):
                self.graph = graph
                self.visits = 0
                self.edges = []
                # Lazily enumerates candidate extensions of this graph, best-first under the policy
                self.generator = owner.model.bestFirstEnumeration(specEncoding, graph, objectEncodings)

        class Edge:
            def __init__(self, parent, child, logLikelihood):
                self.logLikelihood = logLikelihood
                self.parent = parent
                self.child = child
                self.traversals = 0
                self.totalReward = 0

        # Encode the spec and the objects in scope once, up front; Node closes over these
        specEncoding = self.model.specEncoder(spec)
        objectEncodings = ScopeEncoding(self.model, spec)
        def expand(n):
            """Adds a single child to a node"""
            if n.generator is None: return

            t0 = time.time()
            try: o, ll = next(n.generator)
            except StopIteration:
                n.generator = None
                o, ll = None, None
            self.beamTime += time.time() - t0

            if o is None or o in n.graph.nodes: return

            newGraph = n.graph.extend(o)
            if newGraph in graph2node:
                child = graph2node[newGraph]
            else:
                self._report(newGraph)
                child = Node(newGraph)
            e = Edge(n, child, ll)
            n.edges.append(e)
        def rollout(g):
            t0 = time.time()
            depth = 0
            while True:
                samples = self.model.repeatedlySample(specEncoding, g, objectEncodings, 1)
                assert len(samples) <= 1
                depth += 1
                if len(samples) == 0 or samples[0] is None: break
                g = g.extend(samples[0])
                if self.rolloutDepth is not None and depth >= self.rolloutDepth: break

            self.rollingTime += time.time() - t0
            self._report(g)
            return g
        def uct(e):
            # Exploitation: average reward Q(s,a) over traversals of this edge
            if e.traversals == 0: q = 0.
            else: q = e.totalReward/e.traversals

            # Exploration, biased by the policy (PUCT-style bonus)
            exploration_bonus = math.exp(e.logLikelihood) * (e.parent.visits**0.5) / (1. + e.traversals)

            # Trade off exploration against exploitation
            return q + self.c_puct*exploration_bonus
        rootNode = Node(ProgramGraph([]))
        graph2node = {ProgramGraph([]): rootNode}

        while time.time() - startTime < timeout:
            # Selection: descend the tree by maximizing UCT at each node
            n = rootNode
            trajectory = []  # list of traversed edges
            while len(n.edges) > 0:
                e = max(n.edges, key=uct)
                trajectory.append(e)
                n = e.child

            # Simulation: roll out from the leaf and convert its loss into a reward
            r = self.reward(self.loss(rollout(n.graph)))

            # Expand nodes if their single visit-0 child was visited
            for e in trajectory:
                if e.child.visits == 0:
                    expand(e.parent)

            # Backup: propagate the reward along the traversed edges
            for e in trajectory:
                e.totalReward += r
                e.traversals += 1
                e.parent.visits += 1

            expand(n)
            n.visits += 1
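
# --------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file). It shows how this
# solver might be driven, assuming a trained `model` that exposes specEncoder,
# bestFirstEnumeration, and repeatedlySample, plus a task `spec` and a
# per-program `loss` function as expected by the Solver base class. The
# exponential reward mapping and the `infer(...)` wrapper around `_infer`
# (which also sets `self.loss`) are assumptions, not confirmed by this file.
#
#   solver = MCTS(model, reward=lambda l: math.exp(-l))
#   best = solver.infer(spec, loss, timeout=60)
#   print(solver.beamTime, solver.rollingTime)
# --------------------------------------------------------------------------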