import argparse from pathlib import Path import numpy as np import torch from game import UltimateTicTacToe from mcts import MCTS from model import UltimateTicTacToeModel from trainer import Trainer DEFAULT_ARGS = { "num_simulations": 100, "numIters": 50, "numEps": 20, "epochs": 5, "batch_size": 64, "lr": 5e-4, "weight_decay": 1e-4, "replay_buffer_size": 50000, "value_loss_weight": 1.0, "grad_clip_norm": 5.0, "checkpoint_path": "latest.pth", "temperature_threshold": 10, "root_dirichlet_alpha": 0.3, "root_exploration_fraction": 0.25, "arena_compare_games": 6, "arena_accept_threshold": 0.55, "arena_compare_simulations": 8, } def get_device(device_arg): if device_arg: return device_arg return "cuda" if torch.cuda.is_available() else "cpu" def build_model(game, device): return UltimateTicTacToeModel( game.get_board_size(), game.get_action_size(), device, ) def load_checkpoint(model, checkpoint_path, device, optimizer=None, required=True): checkpoint = Path(checkpoint_path) if not checkpoint.exists(): if required: raise FileNotFoundError(f"Checkpoint not found: {checkpoint}") return False state = torch.load(checkpoint, map_location=device) model.load_state_dict(state["state_dict"]) if optimizer is not None and "optimizer_state_dict" in state: optimizer.load_state_dict(state["optimizer_state_dict"]) model.eval() return True def canonical_state(game, state, player): board_data, active_board = state return (game.get_canonical_board_data(board_data, player), active_board) def apply_moves(game, moves): state = game.get_init_board() player = 1 for action in moves: next_state = game.get_next_state(state, player, action, verify_move=True) if next_state is False: raise ValueError(f"Illegal move in sequence: {action}") state, player = next_state return state, player def format_board(board_data): symbols = {1: "X", -1: "O", 0: "."} rows = [] for row in range(9): cells = [symbols[int(board_data[row * 9 + col])] for col in range(9)] groups = [" ".join(cells[idx:idx + 3]) for idx in (0, 3, 6)] rows.append(" | ".join(groups)) if row in (2, 5): rows.append("-" * 23) return "\n".join(rows) def top_policy_moves(policy, limit): ranked = np.argsort(policy)[::-1][:limit] return [(int(action), float(policy[action])) for action in ranked] def parse_moves(text): if not text: return [] return [int(part.strip()) for part in text.split(",") if part.strip()] def parse_action(text): raw = text.strip().replace(",", " ").split() if len(raw) == 1: action = int(raw[0]) elif len(raw) == 2: row, col = (int(value) for value in raw) if not (0 <= row < 9 and 0 <= col < 9): raise ValueError("Row and column must be in [0, 8].") action = row * 9 + col else: raise ValueError("Enter either a flat move index or 'row col'.") if not (0 <= action < 81): raise ValueError("Move index must be in [0, 80].") return action def scalar_value(value): return float(np.asarray(value).reshape(-1)[0]) def train_command(args): device = get_device(args.device) game = UltimateTicTacToe() model = build_model(game, device) train_args = dict(DEFAULT_ARGS) train_args.update( { "num_simulations": args.num_simulations, "numIters": args.num_iters, "numEps": args.num_eps, "epochs": args.epochs, "batch_size": args.batch_size, "lr": args.lr, "weight_decay": args.weight_decay, "replay_buffer_size": args.replay_buffer_size, "value_loss_weight": args.value_loss_weight, "grad_clip_norm": args.grad_clip_norm, "checkpoint_path": args.checkpoint, "temperature_threshold": args.temperature_threshold, "root_dirichlet_alpha": args.root_dirichlet_alpha, "root_exploration_fraction": args.root_exploration_fraction, "arena_compare_games": args.arena_compare_games, "arena_accept_threshold": args.arena_accept_threshold, "arena_compare_simulations": args.arena_compare_simulations, } ) trainer = Trainer(game, model, train_args) if args.resume: load_checkpoint(model, args.checkpoint, device, optimizer=trainer.optimizer) trainer.learn() def eval_command(args): device = get_device(args.device) game = UltimateTicTacToe() model = build_model(game, device) load_checkpoint(model, args.checkpoint, device) moves = parse_moves(args.moves) state, player = apply_moves(game, moves) current_state = canonical_state(game, state, player) encoded = game.encode_state(current_state) policy, value = model.predict(encoded) legal_mask = np.array(game.get_valid_moves(state), dtype=np.float32) policy = policy * legal_mask if policy.sum() > 0: policy = policy / policy.sum() print("Board:") print(format_board(state[0])) print() print(f"Side to move: {'X' if player == 1 else 'O'}") print(f"Active small board: {state[1]}") print(f"Model value: {scalar_value(value):.4f}") print("Top policy moves:") for action, prob in top_policy_moves(policy, args.top_k): print(f" {action:2d} -> {prob:.4f}") if args.with_mcts: mcts_args = dict(DEFAULT_ARGS) mcts_args.update( { "num_simulations": args.num_simulations, "root_dirichlet_alpha": None, "root_exploration_fraction": None, } ) root = MCTS(game, model, mcts_args).run(model, current_state, to_play=1) action = root.select_action(temperature=0) print(f"MCTS best move: {action}") def ai_action(game, model, state, player, num_simulations): current_state = canonical_state(game, state, player) mcts_args = dict(DEFAULT_ARGS) mcts_args.update( { "num_simulations": num_simulations, "root_dirichlet_alpha": None, "root_exploration_fraction": None, } ) root = MCTS(game, model, mcts_args).run(model, current_state, to_play=1) return root.select_action(temperature=0) def random_action(game, state): legal_actions = [index for index, allowed in enumerate(game.get_valid_moves(state)) if allowed] if not legal_actions: raise ValueError("No legal actions available.") return int(np.random.choice(legal_actions)) def load_player_model(game, checkpoint, device): model = build_model(game, device) load_checkpoint(model, checkpoint, device) return model def choose_action(game, player_kind, model, state, player, num_simulations): if player_kind == "random": return random_action(game, state) return ai_action(game, model, state, player, num_simulations) def play_match(game, x_kind, x_model, o_kind, o_model, num_simulations): state = game.get_init_board() player = 1 while True: reward = game.get_reward_for_player(state, player) if reward is not None: if reward == 0: return 0 return player if reward == 1 else -player if player == 1: action = choose_action(game, x_kind, x_model, state, player, num_simulations) else: action = choose_action(game, o_kind, o_model, state, player, num_simulations) state, player = game.get_next_state(state, player, action) def arena_command(args): device = get_device(args.device) game = UltimateTicTacToe() x_model = None o_model = None if args.x_player == "checkpoint": x_model = load_player_model(game, args.x_checkpoint, device) if args.o_player == "checkpoint": o_model = load_player_model(game, args.o_checkpoint, device) results = {1: 0, -1: 0, 0: 0} for _ in range(args.games): winner = play_match( game, args.x_player, x_model, args.o_player, o_model, args.num_simulations, ) results[winner] += 1 print(f"Games: {args.games}") print(f"X ({args.x_player}) wins: {results[1]}") print(f"O ({args.o_player}) wins: {results[-1]}") print(f"Draws: {results[0]}") def play_command(args): device = get_device(args.device) game = UltimateTicTacToe() model = build_model(game, device) load_checkpoint(model, args.checkpoint, device) state = game.get_init_board() player = 1 human_player = args.human_player while True: print() print(format_board(state[0])) print(f"Turn: {'X' if player == 1 else 'O'}") print(f"Active small board: {state[1]}") reward = game.get_reward_for_player(state, player) if reward is not None: if reward == 0: print("Result: draw") else: winner = player if reward == 1 else -player print(f"Winner: {'X' if winner == 1 else 'O'}") return valid_moves = game.get_valid_moves(state) legal_actions = [index for index, allowed in enumerate(valid_moves) if allowed] print(f"Legal moves: {legal_actions}") if player == human_player: while True: try: action = parse_action(input("Your move (index or 'row col'): ")) next_state = game.get_next_state(state, player, action, verify_move=True) if next_state is False: raise ValueError(f"Illegal move: {action}") state, player = next_state break except ValueError as exc: print(exc) else: action = ai_action(game, model, state, player, args.num_simulations) print(f"AI move: {action}") state, player = game.get_next_state(state, player, action) def build_parser(): parser = argparse.ArgumentParser(description="Ultimate Tic-Tac-Toe Runner") subparsers = parser.add_subparsers(dest="command", required=True) train_parser = subparsers.add_parser("train", help="Train the model with self-play") train_parser.add_argument("--device") train_parser.add_argument("--checkpoint", default=DEFAULT_ARGS["checkpoint_path"]) train_parser.add_argument("--resume", action="store_true") train_parser.add_argument("--num-simulations", type=int, default=DEFAULT_ARGS["num_simulations"]) train_parser.add_argument("--num-iters", type=int, default=DEFAULT_ARGS["numIters"]) train_parser.add_argument("--num-eps", type=int, default=DEFAULT_ARGS["numEps"]) train_parser.add_argument("--epochs", type=int, default=DEFAULT_ARGS["epochs"]) train_parser.add_argument("--batch-size", type=int, default=DEFAULT_ARGS["batch_size"]) train_parser.add_argument("--lr", type=float, default=DEFAULT_ARGS["lr"]) train_parser.add_argument("--weight-decay", type=float, default=DEFAULT_ARGS["weight_decay"]) train_parser.add_argument("--replay-buffer-size", type=int, default=DEFAULT_ARGS["replay_buffer_size"]) train_parser.add_argument("--value-loss-weight", type=float, default=DEFAULT_ARGS["value_loss_weight"]) train_parser.add_argument("--grad-clip-norm", type=float, default=DEFAULT_ARGS["grad_clip_norm"]) train_parser.add_argument("--temperature-threshold", type=int, default=DEFAULT_ARGS["temperature_threshold"]) train_parser.add_argument("--root-dirichlet-alpha", type=float, default=DEFAULT_ARGS["root_dirichlet_alpha"]) train_parser.add_argument("--root-exploration-fraction", type=float, default=DEFAULT_ARGS["root_exploration_fraction"]) train_parser.add_argument("--arena-compare-games", type=int, default=DEFAULT_ARGS["arena_compare_games"]) train_parser.add_argument("--arena-accept-threshold", type=float, default=DEFAULT_ARGS["arena_accept_threshold"]) train_parser.add_argument("--arena-compare-simulations", type=int, default=DEFAULT_ARGS["arena_compare_simulations"]) train_parser.set_defaults(func=train_command) eval_parser = subparsers.add_parser("eval", help="Inspect a checkpoint on a position") eval_parser.add_argument("--device") eval_parser.add_argument("--checkpoint", default=DEFAULT_ARGS["checkpoint_path"]) eval_parser.add_argument("--moves", default="", help="Comma-separated move sequence") eval_parser.add_argument("--top-k", type=int, default=10) eval_parser.add_argument("--with-mcts", action="store_true") eval_parser.add_argument("--num-simulations", type=int, default=DEFAULT_ARGS["num_simulations"]) eval_parser.set_defaults(func=eval_command) play_parser = subparsers.add_parser("play", help="Play against the checkpoint") play_parser.add_argument("--device") play_parser.add_argument("--checkpoint", default=DEFAULT_ARGS["checkpoint_path"]) play_parser.add_argument("--human-player", type=int, choices=[1, -1], default=1) play_parser.add_argument("--num-simulations", type=int, default=DEFAULT_ARGS["num_simulations"]) play_parser.set_defaults(func=play_command) arena_parser = subparsers.add_parser("arena", help="Run repeated matches between agents") arena_parser.add_argument("--device") arena_parser.add_argument("--games", type=int, default=20) arena_parser.add_argument("--num-simulations", type=int, default=DEFAULT_ARGS["num_simulations"]) arena_parser.add_argument("--x-player", choices=["checkpoint", "random"], default="checkpoint") arena_parser.add_argument("--o-player", choices=["checkpoint", "random"], default="random") arena_parser.add_argument("--x-checkpoint", default=DEFAULT_ARGS["checkpoint_path"]) arena_parser.add_argument("--o-checkpoint", default=DEFAULT_ARGS["checkpoint_path"]) arena_parser.set_defaults(func=arena_command) return parser def main(): parser = build_parser() args = parser.parse_args() args.func(args) if __name__ == "__main__": main()