First commit

This commit is contained in:
2026-03-23 21:19:29 +01:00
commit 29fc731e6c
7 changed files with 1173 additions and 0 deletions

384
run.py Normal file
View File

@@ -0,0 +1,384 @@
import argparse
from pathlib import Path
import numpy as np
import torch
from game import UltimateTicTacToe
from mcts import MCTS
from model import UltimateTicTacToeModel
from trainer import Trainer
DEFAULT_ARGS = {
"num_simulations": 100,
"numIters": 50,
"numEps": 20,
"epochs": 5,
"batch_size": 64,
"lr": 5e-4,
"weight_decay": 1e-4,
"replay_buffer_size": 50000,
"value_loss_weight": 1.0,
"grad_clip_norm": 5.0,
"checkpoint_path": "latest.pth",
"temperature_threshold": 10,
"root_dirichlet_alpha": 0.3,
"root_exploration_fraction": 0.25,
"arena_compare_games": 6,
"arena_accept_threshold": 0.55,
"arena_compare_simulations": 8,
}
def get_device(device_arg):
if device_arg:
return device_arg
return "cuda" if torch.cuda.is_available() else "cpu"
def build_model(game, device):
return UltimateTicTacToeModel(
game.get_board_size(),
game.get_action_size(),
device,
)
def load_checkpoint(model, checkpoint_path, device, optimizer=None, required=True):
checkpoint = Path(checkpoint_path)
if not checkpoint.exists():
if required:
raise FileNotFoundError(f"Checkpoint not found: {checkpoint}")
return False
state = torch.load(checkpoint, map_location=device)
model.load_state_dict(state["state_dict"])
if optimizer is not None and "optimizer_state_dict" in state:
optimizer.load_state_dict(state["optimizer_state_dict"])
model.eval()
return True
def canonical_state(game, state, player):
board_data, active_board = state
return (game.get_canonical_board_data(board_data, player), active_board)
def apply_moves(game, moves):
state = game.get_init_board()
player = 1
for action in moves:
next_state = game.get_next_state(state, player, action, verify_move=True)
if next_state is False:
raise ValueError(f"Illegal move in sequence: {action}")
state, player = next_state
return state, player
def format_board(board_data):
symbols = {1: "X", -1: "O", 0: "."}
rows = []
for row in range(9):
cells = [symbols[int(board_data[row * 9 + col])] for col in range(9)]
groups = [" ".join(cells[idx:idx + 3]) for idx in (0, 3, 6)]
rows.append(" | ".join(groups))
if row in (2, 5):
rows.append("-" * 23)
return "\n".join(rows)
def top_policy_moves(policy, limit):
ranked = np.argsort(policy)[::-1][:limit]
return [(int(action), float(policy[action])) for action in ranked]
def parse_moves(text):
if not text:
return []
return [int(part.strip()) for part in text.split(",") if part.strip()]
def parse_action(text):
raw = text.strip().replace(",", " ").split()
if len(raw) == 1:
action = int(raw[0])
elif len(raw) == 2:
row, col = (int(value) for value in raw)
if not (0 <= row < 9 and 0 <= col < 9):
raise ValueError("Row and column must be in [0, 8].")
action = row * 9 + col
else:
raise ValueError("Enter either a flat move index or 'row col'.")
if not (0 <= action < 81):
raise ValueError("Move index must be in [0, 80].")
return action
def scalar_value(value):
return float(np.asarray(value).reshape(-1)[0])
def train_command(args):
device = get_device(args.device)
game = UltimateTicTacToe()
model = build_model(game, device)
train_args = dict(DEFAULT_ARGS)
train_args.update(
{
"num_simulations": args.num_simulations,
"numIters": args.num_iters,
"numEps": args.num_eps,
"epochs": args.epochs,
"batch_size": args.batch_size,
"lr": args.lr,
"weight_decay": args.weight_decay,
"replay_buffer_size": args.replay_buffer_size,
"value_loss_weight": args.value_loss_weight,
"grad_clip_norm": args.grad_clip_norm,
"checkpoint_path": args.checkpoint,
"temperature_threshold": args.temperature_threshold,
"root_dirichlet_alpha": args.root_dirichlet_alpha,
"root_exploration_fraction": args.root_exploration_fraction,
"arena_compare_games": args.arena_compare_games,
"arena_accept_threshold": args.arena_accept_threshold,
"arena_compare_simulations": args.arena_compare_simulations,
}
)
trainer = Trainer(game, model, train_args)
if args.resume:
load_checkpoint(model, args.checkpoint, device, optimizer=trainer.optimizer)
trainer.learn()
def eval_command(args):
device = get_device(args.device)
game = UltimateTicTacToe()
model = build_model(game, device)
load_checkpoint(model, args.checkpoint, device)
moves = parse_moves(args.moves)
state, player = apply_moves(game, moves)
current_state = canonical_state(game, state, player)
encoded = game.encode_state(current_state)
policy, value = model.predict(encoded)
legal_mask = np.array(game.get_valid_moves(state), dtype=np.float32)
policy = policy * legal_mask
if policy.sum() > 0:
policy = policy / policy.sum()
print("Board:")
print(format_board(state[0]))
print()
print(f"Side to move: {'X' if player == 1 else 'O'}")
print(f"Active small board: {state[1]}")
print(f"Model value: {scalar_value(value):.4f}")
print("Top policy moves:")
for action, prob in top_policy_moves(policy, args.top_k):
print(f" {action:2d} -> {prob:.4f}")
if args.with_mcts:
mcts_args = dict(DEFAULT_ARGS)
mcts_args.update(
{
"num_simulations": args.num_simulations,
"root_dirichlet_alpha": None,
"root_exploration_fraction": None,
}
)
root = MCTS(game, model, mcts_args).run(model, current_state, to_play=1)
action = root.select_action(temperature=0)
print(f"MCTS best move: {action}")
def ai_action(game, model, state, player, num_simulations):
current_state = canonical_state(game, state, player)
mcts_args = dict(DEFAULT_ARGS)
mcts_args.update(
{
"num_simulations": num_simulations,
"root_dirichlet_alpha": None,
"root_exploration_fraction": None,
}
)
root = MCTS(game, model, mcts_args).run(model, current_state, to_play=1)
return root.select_action(temperature=0)
def random_action(game, state):
legal_actions = [index for index, allowed in enumerate(game.get_valid_moves(state)) if allowed]
if not legal_actions:
raise ValueError("No legal actions available.")
return int(np.random.choice(legal_actions))
def load_player_model(game, checkpoint, device):
model = build_model(game, device)
load_checkpoint(model, checkpoint, device)
return model
def choose_action(game, player_kind, model, state, player, num_simulations):
if player_kind == "random":
return random_action(game, state)
return ai_action(game, model, state, player, num_simulations)
def play_match(game, x_kind, x_model, o_kind, o_model, num_simulations):
state = game.get_init_board()
player = 1
while True:
reward = game.get_reward_for_player(state, player)
if reward is not None:
if reward == 0:
return 0
return player if reward == 1 else -player
if player == 1:
action = choose_action(game, x_kind, x_model, state, player, num_simulations)
else:
action = choose_action(game, o_kind, o_model, state, player, num_simulations)
state, player = game.get_next_state(state, player, action)
def arena_command(args):
device = get_device(args.device)
game = UltimateTicTacToe()
x_model = None
o_model = None
if args.x_player == "checkpoint":
x_model = load_player_model(game, args.x_checkpoint, device)
if args.o_player == "checkpoint":
o_model = load_player_model(game, args.o_checkpoint, device)
results = {1: 0, -1: 0, 0: 0}
for _ in range(args.games):
winner = play_match(
game,
args.x_player,
x_model,
args.o_player,
o_model,
args.num_simulations,
)
results[winner] += 1
print(f"Games: {args.games}")
print(f"X ({args.x_player}) wins: {results[1]}")
print(f"O ({args.o_player}) wins: {results[-1]}")
print(f"Draws: {results[0]}")
def play_command(args):
device = get_device(args.device)
game = UltimateTicTacToe()
model = build_model(game, device)
load_checkpoint(model, args.checkpoint, device)
state = game.get_init_board()
player = 1
human_player = args.human_player
while True:
print()
print(format_board(state[0]))
print(f"Turn: {'X' if player == 1 else 'O'}")
print(f"Active small board: {state[1]}")
reward = game.get_reward_for_player(state, player)
if reward is not None:
if reward == 0:
print("Result: draw")
else:
winner = player if reward == 1 else -player
print(f"Winner: {'X' if winner == 1 else 'O'}")
return
valid_moves = game.get_valid_moves(state)
legal_actions = [index for index, allowed in enumerate(valid_moves) if allowed]
print(f"Legal moves: {legal_actions}")
if player == human_player:
while True:
try:
action = parse_action(input("Your move (index or 'row col'): "))
next_state = game.get_next_state(state, player, action, verify_move=True)
if next_state is False:
raise ValueError(f"Illegal move: {action}")
state, player = next_state
break
except ValueError as exc:
print(exc)
else:
action = ai_action(game, model, state, player, args.num_simulations)
print(f"AI move: {action}")
state, player = game.get_next_state(state, player, action)
def build_parser():
parser = argparse.ArgumentParser(description="Ultimate Tic-Tac-Toe Runner")
subparsers = parser.add_subparsers(dest="command", required=True)
train_parser = subparsers.add_parser("train", help="Train the model with self-play")
train_parser.add_argument("--device")
train_parser.add_argument("--checkpoint", default=DEFAULT_ARGS["checkpoint_path"])
train_parser.add_argument("--resume", action="store_true")
train_parser.add_argument("--num-simulations", type=int, default=DEFAULT_ARGS["num_simulations"])
train_parser.add_argument("--num-iters", type=int, default=DEFAULT_ARGS["numIters"])
train_parser.add_argument("--num-eps", type=int, default=DEFAULT_ARGS["numEps"])
train_parser.add_argument("--epochs", type=int, default=DEFAULT_ARGS["epochs"])
train_parser.add_argument("--batch-size", type=int, default=DEFAULT_ARGS["batch_size"])
train_parser.add_argument("--lr", type=float, default=DEFAULT_ARGS["lr"])
train_parser.add_argument("--weight-decay", type=float, default=DEFAULT_ARGS["weight_decay"])
train_parser.add_argument("--replay-buffer-size", type=int, default=DEFAULT_ARGS["replay_buffer_size"])
train_parser.add_argument("--value-loss-weight", type=float, default=DEFAULT_ARGS["value_loss_weight"])
train_parser.add_argument("--grad-clip-norm", type=float, default=DEFAULT_ARGS["grad_clip_norm"])
train_parser.add_argument("--temperature-threshold", type=int, default=DEFAULT_ARGS["temperature_threshold"])
train_parser.add_argument("--root-dirichlet-alpha", type=float, default=DEFAULT_ARGS["root_dirichlet_alpha"])
train_parser.add_argument("--root-exploration-fraction", type=float, default=DEFAULT_ARGS["root_exploration_fraction"])
train_parser.add_argument("--arena-compare-games", type=int, default=DEFAULT_ARGS["arena_compare_games"])
train_parser.add_argument("--arena-accept-threshold", type=float, default=DEFAULT_ARGS["arena_accept_threshold"])
train_parser.add_argument("--arena-compare-simulations", type=int, default=DEFAULT_ARGS["arena_compare_simulations"])
train_parser.set_defaults(func=train_command)
eval_parser = subparsers.add_parser("eval", help="Inspect a checkpoint on a position")
eval_parser.add_argument("--device")
eval_parser.add_argument("--checkpoint", default=DEFAULT_ARGS["checkpoint_path"])
eval_parser.add_argument("--moves", default="", help="Comma-separated move sequence")
eval_parser.add_argument("--top-k", type=int, default=10)
eval_parser.add_argument("--with-mcts", action="store_true")
eval_parser.add_argument("--num-simulations", type=int, default=DEFAULT_ARGS["num_simulations"])
eval_parser.set_defaults(func=eval_command)
play_parser = subparsers.add_parser("play", help="Play against the checkpoint")
play_parser.add_argument("--device")
play_parser.add_argument("--checkpoint", default=DEFAULT_ARGS["checkpoint_path"])
play_parser.add_argument("--human-player", type=int, choices=[1, -1], default=1)
play_parser.add_argument("--num-simulations", type=int, default=DEFAULT_ARGS["num_simulations"])
play_parser.set_defaults(func=play_command)
arena_parser = subparsers.add_parser("arena", help="Run repeated matches between agents")
arena_parser.add_argument("--device")
arena_parser.add_argument("--games", type=int, default=20)
arena_parser.add_argument("--num-simulations", type=int, default=DEFAULT_ARGS["num_simulations"])
arena_parser.add_argument("--x-player", choices=["checkpoint", "random"], default="checkpoint")
arena_parser.add_argument("--o-player", choices=["checkpoint", "random"], default="random")
arena_parser.add_argument("--x-checkpoint", default=DEFAULT_ARGS["checkpoint_path"])
arena_parser.add_argument("--o-checkpoint", default=DEFAULT_ARGS["checkpoint_path"])
arena_parser.set_defaults(func=arena_command)
return parser
def main():
parser = build_parser()
args = parser.parse_args()
args.func(args)
if __name__ == "__main__":
main()