First commit

This commit is contained in:
2026-03-23 21:19:29 +01:00
commit 29fc731e6c
7 changed files with 1173 additions and 0 deletions

190
mcts.py Normal file
View File

@@ -0,0 +1,190 @@
import torch
import math
import numpy as np
def ucb_score(parent, child):
    """Upper-confidence bound for traversing the edge parent -> child.

    Combines the child's negated mean value (the child is evaluated from
    the opposing player's perspective) with an exploration bonus that is
    proportional to the child's prior and the parent's visit count, and
    shrinks as the child accumulates visits.
    """
    exploration = child.prior * math.sqrt(parent.visit_count) / (child.visit_count + 1)
    # An unvisited child contributes no value estimate yet.
    exploitation = -child.value() if child.visit_count > 0 else 0
    return exploitation + exploration
class Node:
    """A single search-tree node holding visit statistics and child priors."""

    def __init__(self, prior, to_play):
        self.visit_count = 0    # times this node was traversed during search
        self.to_play = to_play  # player to move at this node (+1 / -1)
        self.prior = prior      # prior probability assigned by the policy network
        self.value_sum = 0      # accumulated backed-up value
        self.children = {}      # action -> Node
        self.state = None       # game state, filled in by expand()

    def expanded(self):
        """True once expand() has populated any children."""
        return bool(self.children)

    def value(self):
        """Mean backed-up value; 0 for a node that was never visited."""
        if self.visit_count == 0:
            return 0
        return self.value_sum / self.visit_count

    def select_action(self, temperature):
        """Sample an action from the children's visit-count distribution.

        temperature 0 picks the most-visited action, inf picks uniformly at
        random, and any other value raises the counts to 1/temperature and
        renormalizes before sampling.
        """
        counts = np.array([c.visit_count for c in self.children.values()])
        moves = list(self.children.keys())
        if temperature == 0:
            chosen = moves[np.argmax(counts)]
        elif temperature == float("inf"):
            chosen = np.random.choice(moves)
        else:
            # See paper appendix, "Data Generation".
            weights = counts ** (1 / temperature)
            weights = weights / sum(weights)
            chosen = np.random.choice(moves, p=weights)
        return chosen

    def select_child(self):
        """Return the (action, child) pair maximizing the UCB score.

        Ties keep the first child encountered in insertion order.
        """
        top_score, top_action, top_child = -np.inf, -1, None
        for move, candidate in self.children.items():
            candidate_score = ucb_score(self, candidate)
            if candidate_score > top_score:
                top_score, top_action, top_child = candidate_score, move, candidate
        return top_action, top_child

    def expand(self, state, to_play, action_probs):
        """Attach one child per action with a nonzero prior probability,
        recording the network's prior on each child."""
        self.to_play = to_play
        self.state = state
        for move, p in enumerate(action_probs):
            if p != 0:
                self.children[move] = Node(prior=p, to_play=-self.to_play)

    def __repr__(self):
        """Debugger-friendly one-line summary of this node."""
        return "{} Prior: {} Count: {} Value: {}".format(
            self.state.__str__(),
            "{0:.2f}".format(self.prior),
            self.visit_count,
            self.value(),
        )
class MCTS:
    """Monte Carlo Tree Search driver for a two-player, zero-sum game.

    ``game`` must provide encode_state, get_valid_moves,
    get_reward_for_player, get_next_state and get_canonical_board_data;
    ``model`` must provide ``predict(encoded_state) -> (action_probs, value)``.
    ``args`` is a dict with at least 'num_simulations'; the optional keys
    'root_dirichlet_alpha' and 'root_exploration_fraction' enable
    AlphaZero-style root exploration noise.
    """
    def __init__(self, game, model, args):
        self.game = game
        self.model = model
        self.args = args
    def _masked_policy(self, state, model):
        """Query the network for ``state`` and mask out illegal moves.

        Returns ``(action_probs, value)`` where the probabilities are
        renormalized over the valid moves.  Falls back to a uniform
        distribution over valid moves when the network assigns them zero
        total mass, and returns the (all-zero) mask unchanged when there
        are no valid moves at all.
        """
        encoded_state = self.game.encode_state(state)
        action_probs, value = model.predict(encoded_state)
        valid_moves = np.array(self.game.get_valid_moves(state), dtype=np.float32)
        action_probs = action_probs * valid_moves
        total_prob = np.sum(action_probs)
        total_valid = np.sum(valid_moves)
        if total_valid <= 0:
            # No legal move exists; the caller handles the empty policy.
            return valid_moves, float(value)
        if total_prob <= 0:
            # Network put zero mass on every legal move: uniform fallback.
            action_probs = valid_moves / total_valid
        else:
            action_probs /= total_prob
        return action_probs, float(value)
    def _add_exploration_noise(self, node):
        """Mix Dirichlet noise into the root children's priors.

        No-op unless both 'root_dirichlet_alpha' and
        'root_exploration_fraction' are configured and the node has children.
        """
        alpha = self.args.get('root_dirichlet_alpha')
        fraction = self.args.get('root_exploration_fraction')
        if alpha is None or fraction is None or not node.children:
            return
        actions = list(node.children.keys())
        noise = np.random.dirichlet([alpha] * len(actions))
        for action, sample in zip(actions, noise):
            child = node.children[action]
            child.prior = child.prior * (1 - fraction) + sample * fraction
    def run(self, model, state, to_play):
        """Run ``args['num_simulations']`` simulations from ``state``.

        ``model`` overrides ``self.model`` when truthy; ``to_play`` is the
        player (+1/-1) to move at the root.  Returns the root Node, whose
        children carry the visit counts used for action selection.
        """
        model = model or self.model
        root = Node(0, to_play)
        root.state = state
        # Terminal root: nothing to search, record the known reward.
        reward = self.game.get_reward_for_player(state, player=1)
        if reward is not None:
            root.value_sum = float(reward)
            root.visit_count = 1
            return root
        # EXPAND root with the masked network policy.
        action_probs, value = self._masked_policy(state, model)
        root.expand(state, to_play, action_probs)
        if not root.children:
            # No legal moves: treat the root as a leaf with the network value.
            root.value_sum = float(value)
            root.visit_count = 1
            return root
        self._add_exploration_noise(root)
        for _ in range(self.args['num_simulations']):
            node = root
            search_path = [node]
            # SELECT: walk down to a leaf by UCB score.
            xp = False
            while node.expanded():
                action, node = node.select_child()
                if node is None:
                    # Defensive branch: select_child found no child.
                    # NOTE(review): this backpropagates the stale ``value``
                    # left over from the previous expansion — confirm that
                    # is intended rather than re-evaluating the leaf.
                    parent = search_path[-1]
                    self.backpropagate(search_path, value, parent.to_play * -1)
                    xp = True
                    break
                search_path.append(node)
            if xp:
                continue
            parent = search_path[-2]
            state = parent.state
            # Now we're at a leaf node and we would like to expand.
            # Players always play from their own perspective, so apply the
            # chosen action as player 1 ...
            next_state, _ = self.game.get_next_state(state, player=1, action=action)
            # ... then flip the board to the other player's perspective.
            next_state_data, next_state_inner = next_state
            next_state = (self.game.get_canonical_board_data(next_state_data, player=-1), next_state_inner)
            # The value of the new state from the perspective of the other player.
            value = self.game.get_reward_for_player(next_state, player=1)
            if value is None:
                # Game has not ended: EXPAND the leaf with the network policy.
                action_probs, value = self._masked_policy(next_state, model)
                node.expand(next_state, parent.to_play * -1, action_probs)
            self.backpropagate(search_path, float(value), parent.to_play * -1)
        return root
    def backpropagate(self, search_path, value, to_play):
        """Propagate ``value`` (stated from ``to_play``'s perspective) up the
        search path to the root: add it to nodes sharing that perspective,
        subtract it otherwise, and bump every visit count.
        """
        for node in reversed(search_path):
            node.value_sum += value if node.to_play == to_play else -value
            node.visit_count += 1