# Monte Carlo Tree Search (AlphaZero-style): UCB node selection,
# network-guided expansion, and value backpropagation.
import torch
|
|
import math
|
|
import numpy as np
|
|
|
|
|
|
def ucb_score(parent, child):
    """Score the action that transitions from *parent* to *child*.

    Combines an exploration bonus (prior scaled by the parent's visit
    count, damped by the child's own visits) with an exploitation term:
    the child's mean value negated, because the child is evaluated from
    the opposing player's perspective. Unvisited children contribute no
    value term.
    """
    exploration = child.prior * math.sqrt(parent.visit_count) / (1 + child.visit_count)
    exploitation = -child.value() if child.visit_count > 0 else 0
    return exploitation + exploration
|
|
|
|
|
|
class Node:
    """A single MCTS tree node holding visit statistics and the policy prior."""

    def __init__(self, prior, to_play):
        self.visit_count = 0   # times this node was traversed during search
        self.to_play = to_play # player to move at this node (+1 / -1)
        self.prior = prior     # prior probability assigned by the network
        self.value_sum = 0     # accumulated backed-up value
        self.children = {}     # action -> child Node
        self.state = None      # game state, filled in on expansion

    def expanded(self):
        """Return True once this node has at least one child."""
        return bool(self.children)

    def value(self):
        """Mean backed-up value of this node; 0 if it was never visited."""
        if self.visit_count == 0:
            return 0
        return self.value_sum / self.visit_count

    def select_action(self, temperature):
        """Sample an action from the visit-count distribution softened by *temperature*.

        temperature 0 is greedy (most visited), infinity is uniform random,
        anything in between raises visit counts to 1/temperature before
        normalizing.
        """
        counts = np.array([child.visit_count for child in self.children.values()])
        moves = list(self.children.keys())
        if temperature == 0:
            # Greedy: pick the most-visited move.
            return moves[np.argmax(counts)]
        if temperature == float("inf"):
            # Fully random among the available moves.
            return np.random.choice(moves)
        # See paper appendix Data Generation
        softened = counts ** (1 / temperature)
        softened = softened / sum(softened)
        return np.random.choice(moves, p=softened)

    def select_child(self):
        """Return the (action, child) pair with the highest UCB score."""
        best_score, best_action, best_child = -np.inf, -1, None

        for move, child in self.children.items():
            candidate = ucb_score(self, child)
            if candidate > best_score:
                best_score, best_action, best_child = candidate, move, child

        return best_action, best_child

    def expand(self, state, to_play, action_probs):
        """Create a child for every action with non-zero prior probability.

        Stores *state* and *to_play* on this node; children belong to the
        opposing player.
        """
        self.to_play = to_play
        self.state = state
        for move, prob in enumerate(action_probs):
            if prob != 0:
                self.children[move] = Node(prior=prob, to_play=-to_play)

    def __repr__(self):
        """Debugger pretty print node info."""
        prior = "{0:.2f}".format(self.prior)
        return "{} Prior: {} Count: {} Value: {}".format(self.state.__str__(), prior, self.visit_count, self.value())
|
|
|
|
|
|
class MCTS:
    """AlphaZero-style Monte Carlo Tree Search.

    Each simulation selects a leaf via UCB, expands it with the network's
    masked policy, and backpropagates the leaf value along the search path.
    """

    def __init__(self, game, model, args):
        self.game = game    # game interface: encode_state, get_valid_moves, ...
        self.model = model  # default policy/value network (run() may override)
        self.args = args    # dict-like config; must contain 'num_simulations'

    def _masked_policy(self, state, model):
        """Return (action_probs, value) with illegal moves masked out.

        The raw network policy is multiplied by the valid-move mask and
        renormalized. If the network puts zero mass on every legal move,
        fall back to a uniform distribution over the legal moves. If there
        are no legal moves at all, the (all-zero) mask itself is returned.
        """
        encoded_state = self.game.encode_state(state)
        action_probs, value = model.predict(encoded_state)
        valid_moves = np.array(self.game.get_valid_moves(state), dtype=np.float32)
        action_probs = action_probs * valid_moves
        total_prob = np.sum(action_probs)
        total_valid = np.sum(valid_moves)
        if total_valid <= 0:
            # No legal moves: nothing to normalize.
            return valid_moves, float(value)
        if total_prob <= 0:
            # Network gave zero mass to all legal moves: use uniform.
            action_probs = valid_moves / total_valid
        else:
            action_probs /= total_prob
        return action_probs, float(value)

    def _add_exploration_noise(self, node):
        """Mix Dirichlet noise into the root children's priors.

        No-op unless both 'root_dirichlet_alpha' and
        'root_exploration_fraction' are configured and *node* has children.
        """
        alpha = self.args.get('root_dirichlet_alpha')
        fraction = self.args.get('root_exploration_fraction')
        if alpha is None or fraction is None or not node.children:
            return

        actions = list(node.children.keys())
        noise = np.random.dirichlet([alpha] * len(actions))
        for action, sample in zip(actions, noise):
            child = node.children[action]
            child.prior = child.prior * (1 - fraction) + sample * fraction

    def run(self, model, state, to_play):
        """Run 'num_simulations' simulations from *state*; return the root Node.

        *model* may be falsy, in which case self.model is used. Terminal or
        move-less roots are returned immediately with their value recorded.
        """
        model = model or self.model

        root = Node(0, to_play)
        root.state = state

        # Terminal root: record the reward and stop — no search needed.
        reward = self.game.get_reward_for_player(state, player=1)
        if reward is not None:
            root.value_sum = float(reward)
            root.visit_count = 1
            return root

        # EXPAND root
        action_probs, value = self._masked_policy(state, model)
        root.expand(state, to_play, action_probs)
        if not root.children:
            # No legal moves: treat the root as a leaf with the network value.
            root.value_sum = float(value)
            root.visit_count = 1
            return root
        self._add_exploration_noise(root)

        for _ in range(self.args['num_simulations']):
            node = root
            search_path = [node]

            # SELECT: descend by UCB until an unexpanded node is reached.
            xp = False
            while node.expanded():
                action, node = node.select_child()
                if node is None:
                    # Defensive guard: select_child found no child.
                    # NOTE(review): `value` here is left over from the most
                    # recent expansion (possibly the root's) — confirm this
                    # path is actually reachable and that a stale value is
                    # intended before relying on it.
                    parent = search_path[-1]
                    self.backpropagate(search_path, value, parent.to_play * -1)
                    xp = True
                    break
                search_path.append(node)
            if xp:
                continue
            parent = search_path[-2]
            state = parent.state
            # Now we're at a leaf node and we would like to expand
            # Players always play from their own perspective
            next_state, _ = self.game.get_next_state(state, player=1, action=action)
            # Get the board from the perspective of the other player
            next_state_data, next_state_inner = next_state

            next_state = (self.game.get_canonical_board_data(next_state_data, player=-1), next_state_inner)

            # The value of the new state from the perspective of the other player
            value = self.game.get_reward_for_player(next_state, player=1)
            if value is None:
                # If the game has not ended:
                # EXPAND
                action_probs, value = self._masked_policy(next_state, model)
                node.expand(next_state, parent.to_play * -1, action_probs)

            self.backpropagate(search_path, float(value), parent.to_play * -1)

        return root

    def backpropagate(self, search_path, value, to_play):
        """
        At the end of a simulation, we propagate the evaluation all the way up the tree
        to the root.
        """
        for node in reversed(search_path):
            # `value` is from `to_play`'s perspective; flip sign for the opponent.
            node.value_sum += value if node.to_play == to_play else -value
            node.visit_count += 1
|