Skip to content

Commit

Permalink
⚠️[FIX] correct MCTS algo
Browse files Browse the repository at this point in the history
  • Loading branch information
fairyshine committed Oct 11, 2024
1 parent 38b0b6c commit 27bc64f
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions src/fastmindapi/algo/tree/mcts.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
# logging.basicConfig(level=logging.WARNING)
# logger = logging.getLogger('MyLogger')

# from ... import logger

SCALAR=1/(2*math.sqrt(2.0))

class MCTSState(ABC):
Expand Down Expand Up @@ -48,9 +50,9 @@ def __init__(self, state, parent=None):
self.reward=0.0
self.state=state
self.children=[]
self.parent=parent
def add_child(self,child_state):
child=MCTSNode(child_state,self)
self.parent=parent
def add_child(self, child_state):
child=self.__class__(child_state,self)
self.children.append(child)
def update(self,reward):
self.reward+=reward
Expand All @@ -66,11 +68,11 @@ def __repr__(self):

class MCTS_raw:
@classmethod
def UCT_search(cls, budget, root, logger=None, log_freq: int=100):
def UCT_search(cls, budget, root, logger=None, log_freq: int=10):
for iter in range(int(budget)):
if logger is not None:
if iter%log_freq==log_freq-1:
logger.info("simulation: %d"%(iter+1))
logger.info("【Iteration】: %d"%(iter+1))
logger.info(root)
front=cls.tree_policy(root)
reward=cls.default_policy(front.state)
Expand Down Expand Up @@ -128,11 +130,12 @@ def best_child(node,scalar):

@staticmethod
def default_policy(state):
if hasattr(state, 'simulate'):
state = state.simulate()
else:
while not state.terminal():
state=state.next_state()
if not state.terminal():
if hasattr(state, 'simulate'):
state = state.simulate()
else:
while not state.terminal():
state=state.next_state()
return state.reward()

@staticmethod
Expand Down

0 comments on commit 27bc64f

Please sign in to comment.