-
Notifications
You must be signed in to change notification settings - Fork 0
/
mcts.py
201 lines (169 loc) · 7.56 KB
/
mcts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
"""蒙特卡洛树搜索"""
import numpy as np
import copy
from config import CONFIG
def softmax(x):
probs = np.exp(x - np.max(x))
probs /= np.sum(probs)
return probs
# 定义叶子节点
class TreeNode(object):
"""
mcts树中的节点,树的子节点字典中,键为动作,值为TreeNode。记录当前节点选择的动作,以及选择该动作后会跳转到的下一个子节点。
每个节点跟踪其自身的Q(Q不断更新),先验概率P(神经网络给出的预测)及其访问次数调整的u
"""
def __init__(self, parent, prior_p):
"""
:param parent: 当前节点的父节点
:param prior_p: 当前节点被选择的先验概率 神经网络预测值
"""
self._parent = parent
self._children = {} # 从动作到TreeNode的映射
self._n_visits = 0 # 当前当前节点的访问次数
self._Q = 0 # 当前节点对应动作的平均动作价值
self._u = 0 # 当前节点的置信上限 # PUCT算法
self._P = prior_p
def expand(self, action_priors): # 这里把不合法的动作概率全部设置为0
"""通过创建新子节点来展开树"""
for action, prob in action_priors:
if action not in self._children:
self._children[action] = TreeNode(self, prob)
def select(self, c_puct):
"""
在子节点中选择能够提供最大的Q+U的节点
return: (action, next_node)的二元组
"""
return max(self._children.items(),
key=lambda act_node: act_node[1].get_value(c_puct))
def get_value(self, c_puct):
"""
计算并返回此节点的值,它是节点评估Q和此节点的先验的组合
c_puct: 控制相对影响(0, inf)
"""
self._u = (c_puct * self._P *
np.sqrt(self._parent._n_visits) / (1 + self._n_visits)) # puct算法实现
return self._Q + self._u
def update(self, leaf_value):
"""
从叶节点评估中更新节点值
leaf_value: 这个子节点的评估值来自当前玩家的视角
"""
# 统计访问次数
self._n_visits += 1
# 更新Q值,取决于所有访问次数的平均树,使用增量式更新方式 (速度更快,节省内存)
self._Q += 1.0 * (leaf_value - self._Q) / self._n_visits
# 使用递归的方法对所有节点(当前节点对应的支线)进行一次更新
def update_recursive(self, leaf_value):
"""就像调用update()一样,但是对所有直系节点进行更新"""
# 如果它不是根节点,则应首先更新此节点的父节点
if self._parent:
self._parent.update_recursive(-leaf_value)
self.update(leaf_value)
def is_leaf(self):
"""检查是否是叶节点,即没有被扩展的节点"""
return self._children == {}
def is_root(self):
return self._parent is None
# 蒙特卡洛搜索树
class MCTS(object):
# param n_playout: 每次mcts搜索的次数
def __init__(self, policy_value_fn, c_puct=5, n_playout=2000):
"""policy_value_fn: 接收board的盘面状态,返回落子概率和盘面评估得分"""
self._root = TreeNode(None, 1.0) # 父节点走子概率为1
self._policy = policy_value_fn
self._c_puct = c_puct
self._n_playout = n_playout
def _playout(self, state):
"""
进行一次搜索,根据叶节点的评估值进行反向更新树节点的参数
注意:state已就地修改,因此必须提供副本
"""
node = self._root
# 循环找到叶子节点为止
while True:
if node.is_leaf():
break
# 贪心算法选择下一步行动
action, node = node.select(self._c_puct)
state.do_move(action)
# 使用网络评估叶子节点,网络输出(动作,概率)元组p的列表以及当前玩家视角的得分[-1, 1] # 此处为估值
# 使用神经网络进行更新,可以减少搜素次数,即不必走到终局
action_probs, leaf_value = self._policy(state)
# 查看游戏是否结束
end, winner = state.game_end()
if not end:
node.expand(action_probs)
else:
# 对于结束状态,将叶子节点的值换成1或-1 # 即根据真实的值来训练神经网络
if winner == -1: # Tie
leaf_value = 0.0
else:
leaf_value = (
1.0 if winner == state.get_current_player_id() else -1.0
)
# 在本次遍历中更新节点的值和访问次数
# 必须添加符号,因为两个玩家在模拟的时候共用一个搜索树
node.update_recursive(-leaf_value)
def get_move_probs(self, state, temp=1e-3):
"""
按顺序运行所有搜索并返回可用的动作及其相应的概率
state:当前游戏的状态
temp:介于(0, 1]之间的温度参数
"""
for n in range(self._n_playout):
state_copy = copy.deepcopy(state)
self._playout(state_copy)
# 跟据根节点处的访问计数来计算移动概率
act_visits= [(act, node._n_visits)
for act, node in self._root._children.items()]
acts, visits = zip(*act_visits)
act_probs = softmax(1.0 / temp * np.log(np.array(visits) + 1e-10))
return acts, act_probs
def update_with_move(self, last_move):
"""
在当前的树上向前一步,保持我们已经知道的关于子树的一切,即更新根节点到走子的位置等信息
"""
if last_move in self._root._children:
self._root = self._root._children[last_move]
self._root._parent = None
else:
self._root = TreeNode(None, 1.0)
def __str__(self):
return 'MCTS'
# 基于MCTS的AI玩家
class MCTSPlayer(object):
def __init__(self, policy_value_function, c_puct=5, n_playout=2000, is_selfplay=0):
self.mcts = MCTS(policy_value_function, c_puct, n_playout)
self._is_selfplay = is_selfplay
self.agent = "AI"
# 设置颜色
def set_player_ind(self, p):
self.player = p
# 重置搜索树
def reset_player(self):
self.mcts.update_with_move(-1)
def __str__(self):
return 'MCTS {}'.format(self.player)
# 得到行动
def get_action(self, board, temp=1e-3, return_prob=0):
# 像alphaGo_Zero论文一样使用MCTS算法返回的pi向量
move_probs = np.zeros(2086) # 先使用0来占位
acts, probs = self.mcts.get_move_probs(board, temp) # 拿到走子概率
move_probs[list(acts)] = probs # 用搜索出的概率替代占位
if self._is_selfplay:
# 添加Dirichlet Noise进行探索(自我对弈需要)
move = np.random.choice(
acts,
p=0.75*probs + 0.25*np.random.dirichlet(CONFIG['dirichlet'] * np.ones(len(probs))) # 根据论文中设置的0.75
)
# 更新根节点并重用搜索树
self.mcts.update_with_move(move)
else:
# 使用默认的temp=1e-3,它几乎相当于选择具有最高概率的移动
move = np.random.choice(acts, p=probs)
# 重置根节点
self.mcts.update_with_move(-1)
if return_prob:
return move, move_probs
else:
return move