Skip to content

Commit

Permalink
🔨[DEV] Improve MCTS algo
Browse files Browse the repository at this point in the history
  • Loading branch information
fairyshine committed Oct 12, 2024
1 parent 27bc64f commit d542e37
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions src/fastmindapi/algo/tree/mcts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@

SCALAR=1/(2*math.sqrt(2.0))

def calculate_score(node, scalar):
exploit = node.reward / node.visits
explore = math.sqrt(2.0 * math.log(node.visits) / float(node.visits))
return exploit + scalar * explore

class MCTSState(ABC):
MOVES=[]
num_moves=len(MOVES)
Expand Down Expand Up @@ -77,7 +82,7 @@ def UCT_search(cls, budget, root, logger=None, log_freq: int=10):
front=cls.tree_policy(root)
reward=cls.default_policy(front.state)
cls.backup(front,reward)
return cls.best_child(root,0)
return cls.best_leaf(root,SCALAR)

@classmethod
def tree_policy(cls, node):
Expand Down Expand Up @@ -117,17 +122,21 @@ def expand(node):
def best_child(node,scalar):
bestscore=0.0
bestchildren=[]
for c in node.children:
exploit=c.reward/c.visits
explore=math.sqrt(2.0*math.log(node.visits)/float(c.visits))
score=exploit+scalar*explore
for c in node.children:
score = calculate_score(c, scalar)
if score==bestscore:
bestchildren.append(c)
if score>bestscore:
bestchildren=[c]
bestscore=score
return random.choice(bestchildren) if bestchildren != [] else None

@classmethod
def best_leaf(cls, node, scalar):
while node.children != []:
node = cls.best_child(node, scalar)
return node

@staticmethod
def default_policy(state):
if not state.terminal():
Expand Down

0 comments on commit d542e37

Please sign in to comment.