From 29fc731e6c9074e49761ae9ece7f8221d5c294d8 Mon Sep 17 00:00:00 2001 From: BinarySandia04 Date: Mon, 23 Mar 2026 21:19:29 +0100 Subject: [PATCH] First commit --- .gitignore | 4 + README.md | 60 +++++++++ game.py | 196 +++++++++++++++++++++++++++ mcts.py | 190 ++++++++++++++++++++++++++ model.py | 81 +++++++++++ run.py | 384 +++++++++++++++++++++++++++++++++++++++++++++++++++++ trainer.py | 258 +++++++++++++++++++++++++++++++++++ 7 files changed, 1173 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 game.py create mode 100644 mcts.py create mode 100644 model.py create mode 100644 run.py create mode 100644 trainer.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3f561f7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +__pycache__/ +.ipynb_checkpoints/ +*.pth + diff --git a/README.md b/README.md new file mode 100644 index 0000000..81b4559 --- /dev/null +++ b/README.md @@ -0,0 +1,60 @@ +# Ultimate Tic Tac Toe Deep Learning Bot + +**Usage** +Run `python run.py --help` for help. + +**Shared flags** +- --device: Torch device string. If omitted, it auto-picks "cuda" when available, otherwise "cpu". +- --checkpoint: Path to the model checkpoint file. Default is latest.pth. It is loaded for eval, play, and checkpoint-based arena, +and used as the save path for accepted training checkpoints. + +**Training parameters** +- --resume: Loads model and optimizer state from --checkpoint before continuing training. +- --num-simulations default 100: MCTS rollouts per move during self-play. Higher is stronger/slower. +- --num-iters default 50: Number of outer training iterations. Each iteration generates new self-play games, trains, then arena-tests + the new model. +- --num-eps default 20: Self-play games per iteration. +- --epochs default 5: Passes over the current replay-buffer training set per iteration. +- --batch-size default 64: Mini-batch size for gradient updates. +- --lr default 5e-4: Adam learning rate. +- --weight-decay default 1e-4: Adam weight decay (L2-style regularization). +- --replay-buffer-size default 50000: Maximum number of training examples retained across iterations. Older examples are dropped. +- --value-loss-weight default 1.0: Multiplier on the value-head loss in total training loss. Total loss is policy_KL + + value_loss_weight * value_loss. +- --grad-clip-norm default 5.0: Global gradient norm clipping threshold before optimizer step. +- --temperature-threshold default 10: In self-play, moves before this step use stochastic sampling from MCTS visit counts; later + moves use greedy selection. +- --root-dirichlet-alpha default 0.3: Dirichlet noise alpha added to root priors during self-play MCTS to force exploration. +- --root-exploration-fraction default 0.25: How much of that root prior is replaced by Dirichlet noise. +- --arena-compare-games default 6: Number of head-to-head games between candidate and previous model after each iteration. If <= 0, + every candidate is accepted. +- --arena-accept-threshold default 0.55: Minimum average points needed in arena to keep the new model. Win = 1, draw = 0.5. +- --arena-compare-simulations default 8: MCTS simulations per move during those arena comparison games. Separate from self-play + --num-simulations. + +**Evaluation parameters** + +- --moves default "": Comma-separated move list to reach a position from the starting board, e.g. 0,10,4. +- --top-k default 10: How many highest-probability legal moves to print from the model policy. +- --with-mcts: Also run MCTS on that position and print the best move, instead of only raw network policy/value. +- --num-simulations default 100: Only matters with --with-mcts; controls MCTS search depth for that evaluation. + +**Play parameters** + +- --human-player default 1: Which side you control. 1 means X, -1 means O. +- --num-simulations default 100: MCTS simulations the AI uses for each move. + +**Arena parameters** +- --games default 20: Number of matches to run. +- --num-simulations default 100: MCTS simulations per move for checkpoint-based players. +- --x-player / --o-player: Either checkpoint or random. Chooses the agent type for each side. +- --x-checkpoint / --o-checkpoint: Checkpoint path for that side when its player type is checkpoint. Ignored for random. + +A few practical examples: + +```bash +python run.py train --num-iters 100 --num-eps 50 --resume +python run.py eval --checkpoint latest.pth --moves 0,10,4 --with-mcts --num-simulations 200 +python run.py play --human-player -1 --num-simulations 300 +python run.py arena --games 50 --x-player checkpoint --o-player random +``` diff --git a/game.py b/game.py new file mode 100644 index 0000000..aa5e64e --- /dev/null +++ b/game.py @@ -0,0 +1,196 @@ +import numpy as np + +WIN_PATTERNS = [ + (0, 1, 2), + (3, 4, 5), + (6, 7, 8), + (0, 3, 6), + (1, 4, 7), + (2, 5, 8), + (0, 4, 8), + (2, 4, 6), +] + +class UltimateTicTacToe: + """ + A very, very simple game of ConnectX in which we have: + rows: 1 + columns: 4 + winNumber: 2 + """ + + def __init__(self): + self.cells = 81 + self.board_width = 9 + self.state_planes = 9 + + def get_init_board(self): + b = np.zeros((self.cells,), dtype=int) + return (b, None) + + def get_board_size(self): + return (self.state_planes, self.board_width, self.board_width) + + def get_action_size(self): + return self.cells + + def get_next_state(self, board, player, action, verify_move=False): + if verify_move: + if self.get_valid_moves(board)[action] == 0: + return False + new_board_data = np.copy(board[0]) + new_board_data[action] = player + + next_board = ((action // 9) % 3) * 3 + (action % 3) + next_board = next_board if not self.is_board_full(new_board_data, next_board) else None + + # Return the new game, but + # change the perspective of the game with negative + return ((new_board_data, next_board), -player) + + def is_board_full(self, board_data, next_board): + return self._is_small_board_win(board_data, next_board, 1) or self._is_small_board_win(board_data, next_board, -1) or self._is_board_full(board_data, next_board) + + def _small_board_cells(self, inner_board_idx): + row_block = inner_board_idx // 3 + col_block = inner_board_idx % 3 + + base = row_block * 27 + col_block * 3 + + return [ + base, base + 1, base + 2, + base + 9, base + 10, base + 11, + base + 18, base + 19, base + 20 + ] + + def _is_board_full(self, board_data, next_board): + # Check if it is literally full + cells = self._small_board_cells(next_board) + + for a in cells: + if board_data[a] == 0: + return False + return True + + def _is_playable_small_board(self, board_data, inner_board_idx): + return not self.is_board_full(board_data, inner_board_idx) + + def has_legal_moves(self, board): + valid_moves = self.get_valid_moves(board) + for i in valid_moves: + if i == 1: + return True + return False + + def get_valid_moves(self, board): + # All moves are invalid by default + board_data, active_board = board + valid_moves = [0] * self.get_action_size() + + if active_board is not None and not self._is_playable_small_board(board_data, active_board): + active_board = None + + if active_board is None: + playable_boards = [ + inner_board_idx + for inner_board_idx in range(9) + if self._is_playable_small_board(board_data, inner_board_idx) + ] + for inner_board_idx in playable_boards: + for index in self._small_board_cells(inner_board_idx): + if board_data[index] == 0: + valid_moves[index] = 1 + else: + for index in self._small_board_cells(active_board): + if board_data[index] == 0: + valid_moves[index] = 1 + + return valid_moves + + def _is_small_board_win(self, board_data, inner_board_idx, player): + cells = self._small_board_cells(inner_board_idx) + + for a, b, c in WIN_PATTERNS: + if board_data[cells[a]] == board_data[cells[b]] == board_data[cells[c]] == player: + return True + + return False + + def is_win(self, board, player): + board_data, _ = board + won = [self._is_small_board_win(board_data, i, player) for i in range(9)] + + # Check if any winning combination is all 1s + for a, b, c in WIN_PATTERNS: + if won[a] and won[b] and won[c]: + return True + + return False + + def get_reward_for_player(self, board, player): + # return None if not ended, 1 if player 1 wins, -1 if player 1 lost + + if self.is_win(board, player): + return 1 + if self.is_win(board, -player): + return -1 + if self.has_legal_moves(board): + return None + + return 0 + + def get_canonical_board_data(self, board_data, player): + return player * board_data + + def _small_board_mask(self, inner_board_idx): + mask = np.zeros((self.board_width, self.board_width), dtype=np.float32) + for index in self._small_board_cells(inner_board_idx): + row = index // self.board_width + col = index % self.board_width + mask[row, col] = 1.0 + return mask + + def encode_state(self, board): + board_data, active_board = board + board_grid = board_data.reshape(self.board_width, self.board_width) + + current_stones = (board_grid == 1).astype(np.float32) + opponent_stones = (board_grid == -1).astype(np.float32) + empty_cells = (board_grid == 0).astype(np.float32) + legal_moves = np.array(self.get_valid_moves(board), dtype=np.float32).reshape(self.board_width, self.board_width) + + active_board_mask = np.zeros((self.board_width, self.board_width), dtype=np.float32) + if active_board is not None and self._is_playable_small_board(board_data, active_board): + active_board_mask = self._small_board_mask(active_board) + + current_won_boards = np.zeros((self.board_width, self.board_width), dtype=np.float32) + opponent_won_boards = np.zeros((self.board_width, self.board_width), dtype=np.float32) + playable_boards = np.zeros((self.board_width, self.board_width), dtype=np.float32) + + for inner_board_idx in range(9): + board_mask = self._small_board_mask(inner_board_idx) + if self._is_small_board_win(board_data, inner_board_idx, 1): + current_won_boards += board_mask + elif self._is_small_board_win(board_data, inner_board_idx, -1): + opponent_won_boards += board_mask + + if self._is_playable_small_board(board_data, inner_board_idx): + playable_boards += board_mask + + move_count = np.count_nonzero(board_data) / self.cells + move_count_plane = np.full((self.board_width, self.board_width), move_count, dtype=np.float32) + + return np.stack( + ( + current_stones, + opponent_stones, + empty_cells, + legal_moves, + active_board_mask, + current_won_boards, + opponent_won_boards, + playable_boards, + move_count_plane, + ), + axis=0, + ) diff --git a/mcts.py b/mcts.py new file mode 100644 index 0000000..94aecdd --- /dev/null +++ b/mcts.py @@ -0,0 +1,190 @@ +import torch +import math +import numpy as np + + +def ucb_score(parent, child): + """ + The score for an action that would transition between the parent and child. + """ + prior_score = child.prior * math.sqrt(parent.visit_count) / (child.visit_count + 1) + if child.visit_count > 0: + # The value of the child is from the perspective of the opposing player + value_score = -child.value() + else: + value_score = 0 + + return value_score + prior_score + + +class Node: + def __init__(self, prior, to_play): + self.visit_count = 0 + self.to_play = to_play + self.prior = prior + self.value_sum = 0 + self.children = {} + self.state = None + + def expanded(self): + return len(self.children) > 0 + + def value(self): + if self.visit_count == 0: + return 0 + return self.value_sum / self.visit_count + + def select_action(self, temperature): + """ + Select action according to the visit count distribution and the temperature. + """ + visit_counts = np.array([child.visit_count for child in self.children.values()]) + actions = [action for action in self.children.keys()] + if temperature == 0: + action = actions[np.argmax(visit_counts)] + elif temperature == float("inf"): + action = np.random.choice(actions) + else: + # See paper appendix Data Generation + visit_count_distribution = visit_counts ** (1 / temperature) + visit_count_distribution = visit_count_distribution / sum(visit_count_distribution) + action = np.random.choice(actions, p=visit_count_distribution) + + return action + + def select_child(self): + """ + Select the child with the highest UCB score. + """ + best_score = -np.inf + best_action = -1 + best_child = None + + for action, child in self.children.items(): + score = ucb_score(self, child) + if score > best_score: + best_score = score + best_action = action + best_child = child + + return best_action, best_child + + def expand(self, state, to_play, action_probs): + """ + We expand a node and keep track of the prior policy probability given by neural network + """ + self.to_play = to_play + self.state = state + for a, prob in enumerate(action_probs): + if prob != 0: + self.children[a] = Node(prior=prob, to_play=self.to_play * -1) + + def __repr__(self): + """ + Debugger pretty print node info + """ + prior = "{0:.2f}".format(self.prior) + return "{} Prior: {} Count: {} Value: {}".format(self.state.__str__(), prior, self.visit_count, self.value()) + + +class MCTS: + + def __init__(self, game, model, args): + self.game = game + self.model = model + self.args = args + + def _masked_policy(self, state, model): + encoded_state = self.game.encode_state(state) + action_probs, value = model.predict(encoded_state) + valid_moves = np.array(self.game.get_valid_moves(state), dtype=np.float32) + action_probs = action_probs * valid_moves + total_prob = np.sum(action_probs) + total_valid = np.sum(valid_moves) + if total_valid <= 0: + return valid_moves, float(value) + if total_prob <= 0: + action_probs = valid_moves / total_valid + else: + action_probs /= total_prob + return action_probs, float(value) + + def _add_exploration_noise(self, node): + alpha = self.args.get('root_dirichlet_alpha') + fraction = self.args.get('root_exploration_fraction') + if alpha is None or fraction is None or not node.children: + return + + actions = list(node.children.keys()) + noise = np.random.dirichlet([alpha] * len(actions)) + for action, sample in zip(actions, noise): + child = node.children[action] + child.prior = child.prior * (1 - fraction) + sample * fraction + + def run(self, model, state, to_play): + model = model or self.model + + root = Node(0, to_play) + root.state = state + + reward = self.game.get_reward_for_player(state, player=1) + if reward is not None: + root.value_sum = float(reward) + root.visit_count = 1 + return root + + # EXPAND root + action_probs, value = self._masked_policy(state, model) + root.expand(state, to_play, action_probs) + if not root.children: + root.value_sum = float(value) + root.visit_count = 1 + return root + self._add_exploration_noise(root) + + for _ in range(self.args['num_simulations']): + node = root + search_path = [node] + + # SELECT + xp = False + while node.expanded(): + action, node = node.select_child() + if node == None: + parent = search_path[-1] + self.backpropagate(search_path, value, parent.to_play * -1) + xp = True + break + search_path.append(node) + if xp: + continue + parent = search_path[-2] + state = parent.state + # Now we're at a leaf node and we would like to expand + # Players always play from their own perspective + next_state, _ = self.game.get_next_state(state, player=1, action=action) + # Get the board from the perspective of the other player + next_state_data, next_state_inner = next_state + + next_state = (self.game.get_canonical_board_data(next_state_data, player=-1), next_state_inner) + + # The value of the new state from the perspective of the other player + value = self.game.get_reward_for_player(next_state, player=1) + if value is None: + # If the game has not ended: + # EXPAND + action_probs, value = self._masked_policy(next_state, model) + node.expand(next_state, parent.to_play * -1, action_probs) + + self.backpropagate(search_path, float(value), parent.to_play * -1) + + return root + + def backpropagate(self, search_path, value, to_play): + """ + At the end of a simulation, we propagate the evaluation all the way up the tree + to the root. + """ + for node in reversed(search_path): + node.value_sum += value if node.to_play == to_play else -value + node.visit_count += 1 diff --git a/model.py b/model.py new file mode 100644 index 0000000..59ab882 --- /dev/null +++ b/model.py @@ -0,0 +1,81 @@ +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ResidualBlock(nn.Module): + def __init__(self, channels): + super().__init__() + self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(channels) + self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(channels) + + def forward(self, x): + residual = x + x = F.relu(self.bn1(self.conv1(x))) + x = self.bn2(self.conv2(x)) + return F.relu(x + residual) + + +class UltimateTicTacToeModel(nn.Module): + def __init__(self, board_size, action_size, device, channels=64, num_blocks=6): + super().__init__() + + self.action_size = action_size + self.input_shape = board_size + self.input_channels = board_size[0] + self.board_height = board_size[1] + self.board_width = board_size[2] + self.device = torch.device(device) + + self.stem = nn.Sequential( + nn.Conv2d(self.input_channels, channels, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(channels), + nn.ReLU(inplace=True), + ) + self.residual_tower = nn.Sequential(*(ResidualBlock(channels) for _ in range(num_blocks))) + + self.policy_head = nn.Sequential( + nn.Conv2d(channels, 32, kernel_size=1, bias=False), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.policy_fc = nn.Linear(32 * self.board_height * self.board_width, self.action_size) + + self.value_head = nn.Sequential( + nn.Conv2d(channels, 32, kernel_size=1, bias=False), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + ) + self.value_fc1 = nn.Linear(32 * self.board_height * self.board_width, 128) + self.value_fc2 = nn.Linear(128, 1) + + self.to(self.device) + + def forward(self, x): + x = x.view(-1, *self.input_shape) + x = self.stem(x) + x = self.residual_tower(x) + + policy = self.policy_head(x) + policy = torch.flatten(policy, 1) + policy = self.policy_fc(policy) + + value = self.value_head(x) + value = torch.flatten(value, 1) + value = F.relu(self.value_fc1(value)) + value = torch.tanh(self.value_fc2(value)) + + return F.softmax(policy, dim=1), value + + def predict(self, board): + board = torch.as_tensor(board, dtype=torch.float32, device=self.device) + board = board.view(1, *self.input_shape) + self.eval() + with torch.no_grad(): + pi, v = self.forward(board) + + return pi.detach().cpu().numpy()[0], float(v.item()) diff --git a/run.py b/run.py new file mode 100644 index 0000000..aa1c1c7 --- /dev/null +++ b/run.py @@ -0,0 +1,384 @@ +import argparse +from pathlib import Path + +import numpy as np +import torch + +from game import UltimateTicTacToe +from mcts import MCTS +from model import UltimateTicTacToeModel +from trainer import Trainer + + +DEFAULT_ARGS = { + "num_simulations": 100, + "numIters": 50, + "numEps": 20, + "epochs": 5, + "batch_size": 64, + "lr": 5e-4, + "weight_decay": 1e-4, + "replay_buffer_size": 50000, + "value_loss_weight": 1.0, + "grad_clip_norm": 5.0, + "checkpoint_path": "latest.pth", + "temperature_threshold": 10, + "root_dirichlet_alpha": 0.3, + "root_exploration_fraction": 0.25, + "arena_compare_games": 6, + "arena_accept_threshold": 0.55, + "arena_compare_simulations": 8, +} + + +def get_device(device_arg): + if device_arg: + return device_arg + return "cuda" if torch.cuda.is_available() else "cpu" + + +def build_model(game, device): + return UltimateTicTacToeModel( + game.get_board_size(), + game.get_action_size(), + device, + ) + + +def load_checkpoint(model, checkpoint_path, device, optimizer=None, required=True): + checkpoint = Path(checkpoint_path) + if not checkpoint.exists(): + if required: + raise FileNotFoundError(f"Checkpoint not found: {checkpoint}") + return False + + state = torch.load(checkpoint, map_location=device) + model.load_state_dict(state["state_dict"]) + if optimizer is not None and "optimizer_state_dict" in state: + optimizer.load_state_dict(state["optimizer_state_dict"]) + model.eval() + return True + + +def canonical_state(game, state, player): + board_data, active_board = state + return (game.get_canonical_board_data(board_data, player), active_board) + + +def apply_moves(game, moves): + state = game.get_init_board() + player = 1 + for action in moves: + next_state = game.get_next_state(state, player, action, verify_move=True) + if next_state is False: + raise ValueError(f"Illegal move in sequence: {action}") + state, player = next_state + return state, player + + +def format_board(board_data): + symbols = {1: "X", -1: "O", 0: "."} + rows = [] + for row in range(9): + cells = [symbols[int(board_data[row * 9 + col])] for col in range(9)] + groups = [" ".join(cells[idx:idx + 3]) for idx in (0, 3, 6)] + rows.append(" | ".join(groups)) + if row in (2, 5): + rows.append("-" * 23) + return "\n".join(rows) + + +def top_policy_moves(policy, limit): + ranked = np.argsort(policy)[::-1][:limit] + return [(int(action), float(policy[action])) for action in ranked] + + +def parse_moves(text): + if not text: + return [] + return [int(part.strip()) for part in text.split(",") if part.strip()] + + +def parse_action(text): + raw = text.strip().replace(",", " ").split() + if len(raw) == 1: + action = int(raw[0]) + elif len(raw) == 2: + row, col = (int(value) for value in raw) + if not (0 <= row < 9 and 0 <= col < 9): + raise ValueError("Row and column must be in [0, 8].") + action = row * 9 + col + else: + raise ValueError("Enter either a flat move index or 'row col'.") + if not (0 <= action < 81): + raise ValueError("Move index must be in [0, 80].") + return action + + +def scalar_value(value): + return float(np.asarray(value).reshape(-1)[0]) + + +def train_command(args): + device = get_device(args.device) + game = UltimateTicTacToe() + model = build_model(game, device) + + train_args = dict(DEFAULT_ARGS) + train_args.update( + { + "num_simulations": args.num_simulations, + "numIters": args.num_iters, + "numEps": args.num_eps, + "epochs": args.epochs, + "batch_size": args.batch_size, + "lr": args.lr, + "weight_decay": args.weight_decay, + "replay_buffer_size": args.replay_buffer_size, + "value_loss_weight": args.value_loss_weight, + "grad_clip_norm": args.grad_clip_norm, + "checkpoint_path": args.checkpoint, + "temperature_threshold": args.temperature_threshold, + "root_dirichlet_alpha": args.root_dirichlet_alpha, + "root_exploration_fraction": args.root_exploration_fraction, + "arena_compare_games": args.arena_compare_games, + "arena_accept_threshold": args.arena_accept_threshold, + "arena_compare_simulations": args.arena_compare_simulations, + } + ) + + trainer = Trainer(game, model, train_args) + if args.resume: + load_checkpoint(model, args.checkpoint, device, optimizer=trainer.optimizer) + trainer.learn() + + +def eval_command(args): + device = get_device(args.device) + game = UltimateTicTacToe() + model = build_model(game, device) + load_checkpoint(model, args.checkpoint, device) + + moves = parse_moves(args.moves) + state, player = apply_moves(game, moves) + current_state = canonical_state(game, state, player) + encoded = game.encode_state(current_state) + policy, value = model.predict(encoded) + legal_mask = np.array(game.get_valid_moves(state), dtype=np.float32) + policy = policy * legal_mask + if policy.sum() > 0: + policy = policy / policy.sum() + + print("Board:") + print(format_board(state[0])) + print() + print(f"Side to move: {'X' if player == 1 else 'O'}") + print(f"Active small board: {state[1]}") + print(f"Model value: {scalar_value(value):.4f}") + print("Top policy moves:") + for action, prob in top_policy_moves(policy, args.top_k): + print(f" {action:2d} -> {prob:.4f}") + + if args.with_mcts: + mcts_args = dict(DEFAULT_ARGS) + mcts_args.update( + { + "num_simulations": args.num_simulations, + "root_dirichlet_alpha": None, + "root_exploration_fraction": None, + } + ) + root = MCTS(game, model, mcts_args).run(model, current_state, to_play=1) + action = root.select_action(temperature=0) + print(f"MCTS best move: {action}") + + +def ai_action(game, model, state, player, num_simulations): + current_state = canonical_state(game, state, player) + mcts_args = dict(DEFAULT_ARGS) + mcts_args.update( + { + "num_simulations": num_simulations, + "root_dirichlet_alpha": None, + "root_exploration_fraction": None, + } + ) + root = MCTS(game, model, mcts_args).run(model, current_state, to_play=1) + return root.select_action(temperature=0) + + +def random_action(game, state): + legal_actions = [index for index, allowed in enumerate(game.get_valid_moves(state)) if allowed] + if not legal_actions: + raise ValueError("No legal actions available.") + return int(np.random.choice(legal_actions)) + + +def load_player_model(game, checkpoint, device): + model = build_model(game, device) + load_checkpoint(model, checkpoint, device) + return model + + +def choose_action(game, player_kind, model, state, player, num_simulations): + if player_kind == "random": + return random_action(game, state) + return ai_action(game, model, state, player, num_simulations) + + +def play_match(game, x_kind, x_model, o_kind, o_model, num_simulations): + state = game.get_init_board() + player = 1 + + while True: + reward = game.get_reward_for_player(state, player) + if reward is not None: + if reward == 0: + return 0 + return player if reward == 1 else -player + + if player == 1: + action = choose_action(game, x_kind, x_model, state, player, num_simulations) + else: + action = choose_action(game, o_kind, o_model, state, player, num_simulations) + state, player = game.get_next_state(state, player, action) + + +def arena_command(args): + device = get_device(args.device) + game = UltimateTicTacToe() + + x_model = None + o_model = None + if args.x_player == "checkpoint": + x_model = load_player_model(game, args.x_checkpoint, device) + if args.o_player == "checkpoint": + o_model = load_player_model(game, args.o_checkpoint, device) + + results = {1: 0, -1: 0, 0: 0} + for _ in range(args.games): + winner = play_match( + game, + args.x_player, + x_model, + args.o_player, + o_model, + args.num_simulations, + ) + results[winner] += 1 + + print(f"Games: {args.games}") + print(f"X ({args.x_player}) wins: {results[1]}") + print(f"O ({args.o_player}) wins: {results[-1]}") + print(f"Draws: {results[0]}") + + +def play_command(args): + device = get_device(args.device) + game = UltimateTicTacToe() + model = build_model(game, device) + load_checkpoint(model, args.checkpoint, device) + + state = game.get_init_board() + player = 1 + human_player = args.human_player + + while True: + print() + print(format_board(state[0])) + print(f"Turn: {'X' if player == 1 else 'O'}") + print(f"Active small board: {state[1]}") + + reward = game.get_reward_for_player(state, player) + if reward is not None: + if reward == 0: + print("Result: draw") + else: + winner = player if reward == 1 else -player + print(f"Winner: {'X' if winner == 1 else 'O'}") + return + + valid_moves = game.get_valid_moves(state) + legal_actions = [index for index, allowed in enumerate(valid_moves) if allowed] + print(f"Legal moves: {legal_actions}") + + if player == human_player: + while True: + try: + action = parse_action(input("Your move (index or 'row col'): ")) + next_state = game.get_next_state(state, player, action, verify_move=True) + if next_state is False: + raise ValueError(f"Illegal move: {action}") + state, player = next_state + break + except ValueError as exc: + print(exc) + else: + action = ai_action(game, model, state, player, args.num_simulations) + print(f"AI move: {action}") + state, player = game.get_next_state(state, player, action) + + +def build_parser(): + parser = argparse.ArgumentParser(description="Ultimate Tic-Tac-Toe Runner") + subparsers = parser.add_subparsers(dest="command", required=True) + + train_parser = subparsers.add_parser("train", help="Train the model with self-play") + train_parser.add_argument("--device") + train_parser.add_argument("--checkpoint", default=DEFAULT_ARGS["checkpoint_path"]) + train_parser.add_argument("--resume", action="store_true") + train_parser.add_argument("--num-simulations", type=int, default=DEFAULT_ARGS["num_simulations"]) + train_parser.add_argument("--num-iters", type=int, default=DEFAULT_ARGS["numIters"]) + train_parser.add_argument("--num-eps", type=int, default=DEFAULT_ARGS["numEps"]) + train_parser.add_argument("--epochs", type=int, default=DEFAULT_ARGS["epochs"]) + train_parser.add_argument("--batch-size", type=int, default=DEFAULT_ARGS["batch_size"]) + train_parser.add_argument("--lr", type=float, default=DEFAULT_ARGS["lr"]) + train_parser.add_argument("--weight-decay", type=float, default=DEFAULT_ARGS["weight_decay"]) + train_parser.add_argument("--replay-buffer-size", type=int, default=DEFAULT_ARGS["replay_buffer_size"]) + train_parser.add_argument("--value-loss-weight", type=float, default=DEFAULT_ARGS["value_loss_weight"]) + train_parser.add_argument("--grad-clip-norm", type=float, default=DEFAULT_ARGS["grad_clip_norm"]) + train_parser.add_argument("--temperature-threshold", type=int, default=DEFAULT_ARGS["temperature_threshold"]) + train_parser.add_argument("--root-dirichlet-alpha", type=float, default=DEFAULT_ARGS["root_dirichlet_alpha"]) + train_parser.add_argument("--root-exploration-fraction", type=float, default=DEFAULT_ARGS["root_exploration_fraction"]) + train_parser.add_argument("--arena-compare-games", type=int, default=DEFAULT_ARGS["arena_compare_games"]) + train_parser.add_argument("--arena-accept-threshold", type=float, default=DEFAULT_ARGS["arena_accept_threshold"]) + train_parser.add_argument("--arena-compare-simulations", type=int, default=DEFAULT_ARGS["arena_compare_simulations"]) + train_parser.set_defaults(func=train_command) + + eval_parser = subparsers.add_parser("eval", help="Inspect a checkpoint on a position") + eval_parser.add_argument("--device") + eval_parser.add_argument("--checkpoint", default=DEFAULT_ARGS["checkpoint_path"]) + eval_parser.add_argument("--moves", default="", help="Comma-separated move sequence") + eval_parser.add_argument("--top-k", type=int, default=10) + eval_parser.add_argument("--with-mcts", action="store_true") + eval_parser.add_argument("--num-simulations", type=int, default=DEFAULT_ARGS["num_simulations"]) + eval_parser.set_defaults(func=eval_command) + + play_parser = subparsers.add_parser("play", help="Play against the checkpoint") + play_parser.add_argument("--device") + play_parser.add_argument("--checkpoint", default=DEFAULT_ARGS["checkpoint_path"]) + play_parser.add_argument("--human-player", type=int, choices=[1, -1], default=1) + play_parser.add_argument("--num-simulations", type=int, default=DEFAULT_ARGS["num_simulations"]) + play_parser.set_defaults(func=play_command) + + arena_parser = subparsers.add_parser("arena", help="Run repeated matches between agents") + arena_parser.add_argument("--device") + arena_parser.add_argument("--games", type=int, default=20) + arena_parser.add_argument("--num-simulations", type=int, default=DEFAULT_ARGS["num_simulations"]) + arena_parser.add_argument("--x-player", choices=["checkpoint", "random"], default="checkpoint") + arena_parser.add_argument("--o-player", choices=["checkpoint", "random"], default="random") + arena_parser.add_argument("--x-checkpoint", default=DEFAULT_ARGS["checkpoint_path"]) + arena_parser.add_argument("--o-checkpoint", default=DEFAULT_ARGS["checkpoint_path"]) + arena_parser.set_defaults(func=arena_command) + + return parser + + +def main(): + parser = build_parser() + args = parser.parse_args() + args.func(args) + + +if __name__ == "__main__": + main() diff --git a/trainer.py b/trainer.py new file mode 100644 index 0000000..a1c4fe0 --- /dev/null +++ b/trainer.py @@ -0,0 +1,258 @@ +import os +import math +import copy +import numpy as np +from collections import deque +from random import shuffle +from progressbar import ProgressBar, Percentage, Bar, ETA, AdaptiveETA, FormatLabel + + +import torch +import torch.optim as optim + +from mcts import MCTS + +loss = {'ploss': float('inf'), 'pkl': float('inf'), 'vloss': float('inf'), 'current': 0, 'max': 0} + +class LossWidget: + def __call__(self, progress, data=None): + return ( + f" {loss['current']}/{loss['max']} " + f"P.CE: {loss['ploss']:.4f}, P.KL: {loss['pkl']:.4f}, V.Loss: {loss['vloss']:.4f}" + ) + +class Trainer: + + def __init__(self, game, model, args): + self.game = game + self.model = model + self.args = args + self.mcts = MCTS(self.game, self.model, self.args) + self.optimizer = optim.Adam( + self.model.parameters(), + lr=self.args.get('lr', 5e-4), + weight_decay=self.args.get('weight_decay', 1e-4), + ) + self.replay_buffer = deque(maxlen=self.args.get('replay_buffer_size', 50000)) + + def _ai_action(self, model, state, player): + board_data, active_board = state + canonical_state = (self.game.get_canonical_board_data(board_data, player), active_board) + mcts_args = dict(self.args) + mcts_args.update( + { + 'num_simulations': self.args.get('arena_compare_simulations', self.args['num_simulations']), + 'root_dirichlet_alpha': None, + 'root_exploration_fraction': None, + } + ) + root = MCTS(self.game, model, mcts_args).run(model, canonical_state, to_play=1) + return root.select_action(temperature=0) + + def _play_arena_game(self, x_model, o_model): + state = self.game.get_init_board() + current_player = 1 + + while True: + reward = self.game.get_reward_for_player(state, current_player) + if reward is not None: + if reward == 0: + return 0 + return current_player if reward == 1 else -current_player + + model = x_model if current_player == 1 else o_model + action = self._ai_action(model, state, current_player) + state, current_player = self.game.get_next_state(state, current_player, action) + + def evaluate_candidate(self, candidate_model, reference_model): + games = self.args.get('arena_compare_games', 0) + if games <= 0: + return True, 0.5 + + candidate_points = 0.0 + required_points = self.args.get('arena_accept_threshold', 0.55) * games + candidate_first_games = (games + 1) // 2 + candidate_second_games = games // 2 + games_played = 0 + + for _ in range(candidate_first_games): + winner = self._play_arena_game(candidate_model, reference_model) + if winner == 1: + candidate_points += 1.0 + elif winner == 0: + candidate_points += 0.5 + games_played += 1 + remaining_games = games - games_played + if candidate_points >= required_points: + return True, candidate_points / games + if candidate_points + remaining_games < required_points: + return False, candidate_points / games + + for _ in range(candidate_second_games): + winner = self._play_arena_game(reference_model, candidate_model) + if winner == -1: + candidate_points += 1.0 + elif winner == 0: + candidate_points += 0.5 + games_played += 1 + remaining_games = games - games_played + if candidate_points >= required_points: + return True, candidate_points / games + if candidate_points + remaining_games < required_points: + return False, candidate_points / games + + score = candidate_points / games + return score >= self.args.get('arena_accept_threshold', 0.55), score + + def exceute_episode(self): + + train_examples = [] + current_player = 1 + state = self.game.get_init_board() + episode_step = 0 + + while True: + board_data, state_inner_board = state + cannonical_board_data = self.game.get_canonical_board_data(board_data, current_player) + canonical_board = (cannonical_board_data, state_inner_board) + + self.mcts = MCTS(self.game, self.model, self.args) + root = self.mcts.run(self.model, canonical_board, to_play=1) + + action_probs = np.zeros(self.game.get_action_size(), dtype=np.float32) + for k, v in root.children.items(): + action_probs[k] = v.visit_count + + action_probs = action_probs / np.sum(action_probs) + encoded_state = self.game.encode_state(canonical_board) + train_examples.append((encoded_state, current_player, action_probs)) + + temperature_threshold = self.args.get('temperature_threshold', 10) + temperature = 1 if episode_step < temperature_threshold else 0 + action = root.select_action(temperature=temperature) + state, current_player = self.game.get_next_state(state, current_player, action) + reward = self.game.get_reward_for_player(state, current_player) + episode_step += 1 + + if reward is not None: + ret = [] + for hist_state, hist_current_player, hist_action_probs in train_examples: + # [Board, currentPlayer, actionProbabilities, Reward] + ret.append((hist_state, hist_action_probs, reward * ((-1) ** (hist_current_player != current_player)))) + + return ret + + def learn(self): + widgets = [Percentage(), Bar(), AdaptiveETA(), LossWidget()] + pbar = ProgressBar(max_value=self.args['numIters'], widgets=widgets) + pbar.update(0) + for i in range(1, self.args['numIters'] + 1): + + # print("{}/{}".format(i, self.args['numIters'])) + + train_examples = [] + + for eps in range(self.args['numEps']): + iteration_train_examples = self.exceute_episode() + train_examples.extend(iteration_train_examples) + + self.replay_buffer.extend(train_examples) + training_examples = list(self.replay_buffer) + shuffle(training_examples) + shuffle(train_examples) + + reference_model = copy.deepcopy(self.model) + reference_model.eval() + reference_state_dict = copy.deepcopy(self.model.state_dict()) + reference_optimizer_state = copy.deepcopy(self.optimizer.state_dict()) + + results = self.train(training_examples) + accepted, arena_score = self.evaluate_candidate(self.model, reference_model) + if accepted: + filename = self.args['checkpoint_path'] + self.save_checkpoint(folder=".", filename=filename) + else: + self.model.load_state_dict(reference_state_dict) + self.optimizer.load_state_dict(reference_optimizer_state) + + # print((float(results[0]), float(results[1]))) + loss['ploss'] = float(results[0]) + loss['pkl'] = float(results[1]) + loss['vloss'] = float(results[2]) + loss['current'] = i + loss['max'] = self.args['numIters'] + # loss_widget.update(float(results[0]), float(results[1])) + pbar.update(i) + pbar.finish() + + def train(self, examples): + pi_losses = [] + pi_kls = [] + v_losses = [] + device = self.model.device + + for epoch in range(self.args['epochs']): + self.model.train() + shuffled_indices = np.random.permutation(len(examples)) + + for batch_start in range(0, len(examples), self.args['batch_size']): + sample_ids = shuffled_indices[batch_start:batch_start + self.args['batch_size']] + boards, pis, vs = list(zip(*[examples[i] for i in sample_ids])) + boards = torch.FloatTensor(np.array(boards).astype(np.float32)) + target_pis = torch.FloatTensor(np.array(pis)) + target_vs = torch.FloatTensor(np.array(vs).astype(np.float32)) + + # predict + boards = boards.contiguous().to(device) + target_pis = target_pis.contiguous().to(device) + target_vs = target_vs.contiguous().to(device) + + # compute output + out_pi, out_v = self.model(boards) + l_pi = self.loss_pi(target_pis, out_pi) + l_pi_kl = self.loss_pi_kl(target_pis, out_pi) + l_v = self.loss_v(target_vs, out_v) + total_loss = l_pi_kl + self.args.get('value_loss_weight', 1.0) * l_v + + pi_losses.append(float(l_pi.detach())) + pi_kls.append(float(l_pi_kl.detach())) + v_losses.append(float(l_v.detach())) + + self.optimizer.zero_grad() + total_loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.get('grad_clip_norm', 5.0)) + self.optimizer.step() + + # print() + # print("Policy Loss", np.mean(pi_losses)) + # print("Value Loss", np.mean(v_losses)) + + return (np.mean(pi_losses), np.mean(pi_kls), np.mean(v_losses)) + # print("Examples:") + # print(out_pi[0].detach()) + # print(target_pis[0]) + + def loss_pi(self, targets, outputs): + loss = -(targets * torch.log(outputs.clamp_min(1e-8))).sum(dim=1) + return loss.mean() + + def loss_pi_kl(self, targets, outputs): + target_log = torch.log(targets.clamp_min(1e-8)) + output_log = torch.log(outputs.clamp_min(1e-8)) + loss = (targets * (target_log - output_log)).sum(dim=1) + return loss.mean() + + def loss_v(self, targets, outputs): + loss = torch.sum((targets-outputs.view(-1))**2)/targets.size()[0] + return loss + + def save_checkpoint(self, folder, filename): + if not os.path.exists(folder): + os.mkdir(folder) + + filepath = os.path.join(folder, filename) + torch.save({ + 'state_dict': self.model.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), + 'args': self.args, + }, filepath)