Skip to content

Commit

Permalink
🔨[DEV] Improve MCTS
Browse files Browse the repository at this point in the history
  • Loading branch information
fairyshine committed Sep 29, 2024
1 parent 9dd8903 commit abc97c9
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 43 deletions.
2 changes: 1 addition & 1 deletion src/fastmindapi/algo/tree/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .mcts import MCTS
from .mcts import MCTS_explore as MCTS
124 changes: 84 additions & 40 deletions src/fastmindapi/algo/tree/mcts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
import argparse
from abc import ABC, abstractmethod

# from ... import logger
import logging
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger('MyLogger')
# import logging
# logging.basicConfig(level=logging.WARNING)
# logger = logging.getLogger('MyLogger')

SCALAR=1/(2*math.sqrt(2.0))

Expand Down Expand Up @@ -39,7 +38,7 @@ def __repr__(self):

class MCTSNode:
def __init__(self, state, parent=None):
self.visits=1
self.visits=0
self.reward=0.0
self.state=state
self.children=[]
Expand All @@ -50,50 +49,58 @@ def add_child(self,child_state):
def update(self,reward):
self.reward+=reward
self.visits+=1
def fully_expanded(self, num_moves_lambda) -> bool: # 🌟 rewrite
def fully_expanded(self) -> bool: # 🌟 rewrite
num_moves = self.state.num_moves
if num_moves_lambda is not None:
num_moves = num_moves_lambda(self)
if len(self.children)==num_moves:
return True
return False
def __repr__(self):
s="Node; children: %d; visits: %d; reward: %f"%(len(self.children),self.visits,self.reward)
return s

class MCTS:
class MCTS_raw:
@classmethod
def UCT_search(cls, budget, root, num_moves_lambda = None):
def UCT_search(cls, budget, root, logger=None, log_freq: int=100):
for iter in range(int(budget)):
if iter%10000==9999:
logger.info("simulation: %d"%iter)
logger.info(root)
front=cls.tree_policy(root, num_moves_lambda)
if logger is not None:
if iter%log_freq==log_freq-1:
logger.info("simulation: %d"%(iter+1))
logger.info(root)
front=cls.tree_policy(root)
reward=cls.default_policy(front.state)
cls.backup(front,reward)
return cls.best_child(root,0)

@classmethod
def tree_policy(cls, node, num_moves_lambda):
#a hack to force 'exploitation' in a game where there are many options, and you may never/not want to fully expand first
def tree_policy(cls, node):
while not node.state.terminal():
if len(node.children)==0:
if not node.fully_expanded():
return cls.expand(node)
elif random.uniform(0,1)<.5:
node=cls.best_child(node,SCALAR)
else:
if not node.fully_expanded(num_moves_lambda):
return cls.expand(node)
else:
node=cls.best_child(node,SCALAR)
node=cls.best_child(node,SCALAR)

# a hack to force 'exploitation' in a game where there are many options, and you may never/not want to fully expand first
# while not node.state.terminal():
# if len(node.children)==0:
# return cls.expand(node)
# elif random.uniform(0,1)<.5:
# node=cls.best_child(node,SCALAR)
# else:
# if not node.fully_expanded():
# return cls.expand(node)
# else:
# node=cls.best_child(node,SCALAR)
return node

@staticmethod
def expand(node):
# Attach one more child to `node`, built from a state drawn via node.state.next_state().
# Collect the states of all children already attached to this node
tried_children=[c.state for c in node.children]
# Resample until the drawn state is untried (the loop also stops on a terminal state)
new_state=node.state.next_state()
while new_state in tried_children and not new_state.terminal():
while new_state in tried_children and not new_state.terminal(): # membership test via `in`; presumably relies on the state's hash-based equality — confirm against the state class
new_state=node.state.next_state()
# Attach the new state as a child and return the last entry of node.children
node.add_child(new_state)
return node.children[-1]

Expand All @@ -111,14 +118,15 @@ def best_child(node,scalar):
if score>bestscore:
bestchildren=[c]
bestscore=score
if len(bestchildren)==0:
logger.warn("OOPS: no best child found, probably fatal")
return random.choice(bestchildren) if bestchildren != [] else None

@staticmethod
def default_policy(state):
while state.terminal()==False:
state=state.next_state()
if hasattr(state, 'simulate'):
state = state.simulate()
else:
while not state.terminal():
state=state.next_state()
return state.reward()

@staticmethod
Expand All @@ -129,6 +137,22 @@ def backup(node,reward):
node=node.parent
return None

class MCTS_explore(MCTS_raw):
    """MCTS variant whose tree policy randomly mixes descent and expansion.

    Useful when a node has many possible moves and fully expanding it
    before descending is undesirable.
    """

    @classmethod
    def tree_policy(cls, node):
        """Walk down from ``node`` and return the node to simulate from.

        At each non-terminal node:
          * a node with no children is expanded immediately;
          * otherwise a coin flip decides between descending to the best
            child and trying to expand; a fully expanded node always
            descends.

        Returns the newly expanded child, or the terminal node reached.
        """
        while not node.state.terminal():
            # A childless node has nothing to descend into: expand it.
            if not node.children:
                return cls.expand(node)
            # The draw happens only once the node has children, so the
            # random stream is consumed exactly as in the branching form.
            wants_expansion = not (random.uniform(0, 1) < .5)
            if wants_expansion and not node.fully_expanded():
                return cls.expand(node)
            node = cls.best_child(node, SCALAR)
        return node

if __name__=="__main__":
class TESTState(MCTSState):
NUM_TURNS = 10
Expand Down Expand Up @@ -166,16 +190,36 @@ def __repr__(self):
parser.add_argument('--levels', action="store", required=True, type=int, choices=range(TESTState.NUM_TURNS+1))
args=parser.parse_args()

current_node=MCTSNode(TESTState())
for l in range(args.levels):
current_node=MCTS.UCT_search(args.num_sims/(l+1),current_node)
print("level %d"%l)
print("Num Children: %d"%len(current_node.children))
for i,c in enumerate(current_node.children):
print(i,c)
print("Best Child: %s"%current_node.state)
# while current_node is not None:
# print("Best Child: %s"%current_node.state)
# current_node = MCTS.best_child(current_node, 0)
print("--------------------------------")
# Benchmark: run each tree policy `test_time` times from a fresh root, follow
# the greedy (scalar=0) best-child path to a leaf, and count how often the
# last entry of that leaf's `state.moves` has absolute value 2.
test_time = 1000
score_1 = 0  # hits for the plain policy (MCTS_raw)
score_2 = 0  # hits for the exploration-hack policy (MCTS_explore)

# Reference move sequence kept for manual comparison with the printed paths;
# it is not read by the code below.
best_moves = [-20, 27, -16, 14, 18, -15, -8, 6, -4, -2]

for i in range(test_time):
    current_node = MCTSNode(TESTState())
    # The class formerly named `MCTS` is now `MCTS_raw`; the bare name `MCTS`
    # is not defined in this module, so it must not be used here.
    current_node = MCTS_raw.UCT_search(args.num_sims, current_node)
    while current_node.children != []:
        current_node = MCTS_raw.best_child(current_node, 0)
    print(current_node.state.moves)
    if abs(current_node.state.moves[-1]) == 2:
        score_1 += 1

print("----------------------------------")


for i in range(test_time):
    current_node = MCTSNode(TESTState())
    current_node = MCTS_explore.UCT_search(args.num_sims, current_node)
    while current_node.children != []:
        # best_child is inherited unchanged from MCTS_raw.
        current_node = MCTS_explore.best_child(current_node, 0)
    print(current_node.state.moves)
    if abs(current_node.state.moves[-1]) == 2:
        score_2 += 1

print("Score 1: ", score_1)
print("Score 2: ", score_2)

# Command: python src/fastmindapi/algo/tree/mcts.py --num_sims 10000 --levels 8
2 changes: 1 addition & 1 deletion src/fastmindapi/model/llama_cpp/LLM.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def __init__(self, model):
@classmethod
def from_path(cls, model_path: str):
from llama_cpp import Llama
return cls(Llama(model_path, n_gpu_layers=-1, logits_all=True))
return cls(Llama(model_path, n_gpu_layers=-1, logits_all=True, n_ctx=2048))

def __call__(self, input_text: str, max_new_tokens: int=256):
response = self.model(input_text, max_tokens=max_new_tokens)
Expand Down
2 changes: 1 addition & 1 deletion src/fastmindapi/model/openai/ChatModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __call__(self, input_text: str, max_new_tokens: int = 256):
return completion.choices[0].message.content
except Exception as e:
return "【Error】: " + str(e)

def generate(self,
input_text: str,
max_new_tokens: int = 256,
Expand Down

0 comments on commit abc97c9

Please sign in to comment.