import torch  # NOTE(review): unused in this chunk; kept — model/callers elsewhere may need it
import math
import numpy as np


def ucb_score(parent, child):
    """
    The score for an action that would transition between the parent and child.

    Combines the child's negated value (opponent's perspective) with a
    prior-weighted exploration bonus that decays with the child's visit count.
    """
    prior_score = child.prior * math.sqrt(parent.visit_count) / (child.visit_count + 1)
    if child.visit_count > 0:
        # The value of the child is from the perspective of the opposing player
        value_score = -child.value()
    else:
        # Unvisited child: no value estimate yet, rely on the prior alone
        value_score = 0
    return value_score + prior_score


class Node:
    """A single MCTS tree node holding visit statistics and child edges."""

    def __init__(self, prior, to_play):
        self.visit_count = 0
        self.to_play = to_play          # player (+1 / -1) to move at this node
        self.prior = prior              # policy-network prior for reaching this node
        self.value_sum = 0              # accumulated backed-up value
        self.children = {}              # action -> Node
        self.state = None               # game state; filled in on expand()

    def expanded(self):
        """Return True once expand() has populated any children."""
        return len(self.children) > 0

    def value(self):
        """Mean backed-up value; 0 for an unvisited node (avoids div-by-zero)."""
        if self.visit_count == 0:
            return 0
        return self.value_sum / self.visit_count

    def select_action(self, temperature):
        """
        Select action according to the visit count distribution and the temperature.

        temperature == 0   -> greedy argmax over visit counts
        temperature == inf -> uniform random over expanded actions
        otherwise          -> sample proportional to counts^(1/temperature)
        """
        visit_counts = np.array([child.visit_count for child in self.children.values()])
        actions = [action for action in self.children.keys()]
        if temperature == 0:
            action = actions[np.argmax(visit_counts)]
        elif temperature == float("inf"):
            action = np.random.choice(actions)
        else:
            # See paper appendix Data Generation
            visit_count_distribution = visit_counts ** (1 / temperature)
            visit_count_distribution = visit_count_distribution / sum(visit_count_distribution)
            action = np.random.choice(actions, p=visit_count_distribution)
        return action

    def select_child(self):
        """
        Select the child with the highest UCB score.

        Returns (action, child); (-1, None) only if no child scores above -inf.
        """
        best_score = -np.inf
        best_action = -1
        best_child = None
        for action, child in self.children.items():
            score = ucb_score(self, child)
            if score > best_score:
                best_score = score
                best_action = action
                best_child = child
        return best_action, best_child

    def expand(self, state, to_play, action_probs):
        """
        We expand a node and keep track of the prior policy probability
        given by neural network.

        Only actions with non-zero probability get a child; each child is
        assigned the opposing player.
        """
        self.to_play = to_play
        self.state = state
        for a, prob in enumerate(action_probs):
            if prob != 0:
                self.children[a] = Node(prior=prob, to_play=self.to_play * -1)

    def __repr__(self):
        """
        Debugger pretty print node info
        """
        prior = "{0:.2f}".format(self.prior)
        return "{} Prior: {} Count: {} Value: {}".format(
            self.state.__str__(), prior, self.visit_count, self.value()
        )


class MCTS:
    """Monte Carlo Tree Search driver over a game + policy/value model."""

    def __init__(self, game, model, args):
        self.game = game
        self.model = model
        self.args = args   # dict-like: expects 'num_simulations', optional noise keys

    def _masked_policy(self, state, model):
        """
        Query the model and mask its policy to the valid moves of `state`.

        Returns (normalized action_probs, value). If the model assigns zero
        mass to every valid move, falls back to uniform over valid moves; if
        there are no valid moves at all, returns the raw (all-zero) mask.
        """
        encoded_state = self.game.encode_state(state)
        action_probs, value = model.predict(encoded_state)
        valid_moves = np.array(self.game.get_valid_moves(state), dtype=np.float32)
        action_probs = action_probs * valid_moves
        total_prob = np.sum(action_probs)
        total_valid = np.sum(valid_moves)
        if total_valid <= 0:
            # Terminal-ish position with no legal moves: nothing to normalize
            return valid_moves, float(value)
        if total_prob <= 0:
            # Model put all mass on illegal moves: uniform over legal ones
            action_probs = valid_moves / total_valid
        else:
            action_probs /= total_prob
        return action_probs, float(value)

    def _add_exploration_noise(self, node):
        """
        Mix Dirichlet noise into the root priors (AlphaZero-style exploration).

        No-op unless both 'root_dirichlet_alpha' and 'root_exploration_fraction'
        are present in args and the node has children.
        """
        alpha = self.args.get('root_dirichlet_alpha')
        fraction = self.args.get('root_exploration_fraction')
        if alpha is None or fraction is None or not node.children:
            return
        actions = list(node.children.keys())
        noise = np.random.dirichlet([alpha] * len(actions))
        for action, sample in zip(actions, noise):
            child = node.children[action]
            child.prior = child.prior * (1 - fraction) + sample * fraction

    def run(self, model, state, to_play):
        """
        Run `args['num_simulations']` MCTS simulations from `state` and
        return the populated root node.

        `model` may be None, in which case self.model is used.
        """
        model = model or self.model
        root = Node(0, to_play)
        root.state = state

        # Already-terminal root: record the known reward and bail out
        reward = self.game.get_reward_for_player(state, player=1)
        if reward is not None:
            root.value_sum = float(reward)
            root.visit_count = 1
            return root

        # EXPAND root
        action_probs, value = self._masked_policy(state, model)
        root.expand(state, to_play, action_probs)
        if not root.children:
            # No legal moves even though the game reported non-terminal
            root.value_sum = float(value)
            root.visit_count = 1
            return root
        self._add_exploration_noise(root)

        for _ in range(self.args['num_simulations']):
            node = root
            search_path = [node]

            # SELECT: walk down via UCB until an unexpanded leaf
            xp = False
            while node.expanded():
                action, node = node.select_child()
                if node is None:
                    # Defensive: select_child found no finite-scored child.
                    # NOTE(review): backpropagates the `value` left over from
                    # the previous expansion — looks intentional as a
                    # best-effort fallback, but worth confirming.
                    parent = search_path[-1]
                    self.backpropagate(search_path, value, parent.to_play * -1)
                    xp = True
                    break
                search_path.append(node)
            if xp:
                continue

            parent = search_path[-2]
            state = parent.state
            # Now we're at a leaf node and we would like to expand
            # Players always play from their own perspective
            next_state, _ = self.game.get_next_state(state, player=1, action=action)
            # Get the board from the perspective of the other player
            next_state_data, next_state_inner = next_state
            next_state = (
                self.game.get_canonical_board_data(next_state_data, player=-1),
                next_state_inner,
            )
            # The value of the new state from the perspective of the other player
            value = self.game.get_reward_for_player(next_state, player=1)
            if value is None:
                # If the game has not ended:
                # EXPAND
                action_probs, value = self._masked_policy(next_state, model)
                node.expand(next_state, parent.to_play * -1, action_probs)

            self.backpropagate(search_path, float(value), parent.to_play * -1)

        return root

    def backpropagate(self, search_path, value, to_play):
        """
        At the end of a simulation, we propagate the evaluation
        all the way up the tree to the root.

        `value` is credited positively to nodes where `to_play` moves and
        negated for the opponent's nodes.
        """
        for node in reversed(search_path):
            node.value_sum += value if node.to_play == to_play else -value
            node.visit_count += 1