First commit
This commit is contained in:
384
run.py
Normal file
384
run.py
Normal file
@@ -0,0 +1,384 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from game import UltimateTicTacToe
|
||||
from mcts import MCTS
|
||||
from model import UltimateTicTacToeModel
|
||||
from trainer import Trainer
|
||||
|
||||
|
||||
# Baseline hyper-parameters shared by the train/eval/play/arena subcommands.
# CLI flags override individual entries; see build_parser for the mapping.
DEFAULT_ARGS = {
    "num_simulations": 100,  # MCTS simulations per move
    "numIters": 50,  # outer training iterations (self-play + fit + arena)
    "numEps": 20,  # self-play episodes generated per iteration
    "epochs": 5,  # optimizer epochs over the replay buffer per iteration
    "batch_size": 64,
    "lr": 5e-4,  # optimizer learning rate
    "weight_decay": 1e-4,
    "replay_buffer_size": 50000,  # max retained self-play samples
    "value_loss_weight": 1.0,  # presumably scales value-head loss vs policy loss — verify in Trainer
    "grad_clip_norm": 5.0,  # presumably a gradient-norm clip threshold — verify in Trainer
    "checkpoint_path": "latest.pth",
    "temperature_threshold": 10,  # NOTE(review): likely the move count after which sampling turns greedy — confirm in Trainer
    "root_dirichlet_alpha": 0.3,  # Dirichlet noise alpha at the search root (eval/play set this to None)
    "root_exploration_fraction": 0.25,  # fraction of the root prior mixed with noise
    "arena_compare_games": 6,  # games per new-vs-incumbent comparison during training
    "arena_accept_threshold": 0.55,  # presumably the win rate needed to accept a new model — verify in Trainer
    "arena_compare_simulations": 8,  # MCTS sims per move during arena comparison
}
|
||||
|
||||
|
||||
def get_device(device_arg):
    """Resolve the torch device string.

    An explicit (truthy) *device_arg* wins; otherwise auto-detect CUDA,
    falling back to the CPU.
    """
    if not device_arg:
        return "cuda" if torch.cuda.is_available() else "cpu"
    return device_arg
|
||||
|
||||
|
||||
def build_model(game, device):
    """Construct a fresh UltimateTicTacToeModel sized for *game* on *device*."""
    board_size = game.get_board_size()
    action_size = game.get_action_size()
    return UltimateTicTacToeModel(board_size, action_size, device)
|
||||
|
||||
|
||||
def load_checkpoint(model, checkpoint_path, device, optimizer=None, required=True):
    """Populate *model* (and optionally *optimizer*) from a saved checkpoint.

    Returns True on success. A missing file raises FileNotFoundError when
    *required*, otherwise returns False. The model is switched to eval mode
    after loading.
    """
    path = Path(checkpoint_path)
    if not path.exists():
        if not required:
            return False
        raise FileNotFoundError(f"Checkpoint not found: {path}")

    payload = torch.load(path, map_location=device)
    model.load_state_dict(payload["state_dict"])
    # Optimizer state is optional — older checkpoints may not carry it.
    if optimizer is not None and "optimizer_state_dict" in payload:
        optimizer.load_state_dict(payload["optimizer_state_dict"])
    model.eval()
    return True
|
||||
|
||||
|
||||
def canonical_state(game, state, player):
    """Return *state* with its board flipped to *player*'s perspective.

    The active-board component is perspective-independent and passes through.
    """
    board_data, active_board = state
    canonical_board = game.get_canonical_board_data(board_data, player)
    return (canonical_board, active_board)
|
||||
|
||||
|
||||
def apply_moves(game, moves):
    """Replay *moves* from the initial position.

    Returns ``(state, player_to_move)`` after the sequence. Raises
    ValueError on the first move the game rejects.
    """
    current = game.get_init_board()
    to_move = 1
    for move in moves:
        result = game.get_next_state(current, to_move, move, verify_move=True)
        if result is False:
            raise ValueError(f"Illegal move in sequence: {move}")
        current, to_move = result
    return current, to_move
|
||||
|
||||
|
||||
def format_board(board_data):
    """Render a flat 81-cell board as a 9x9 text grid with 3x3 separators.

    board_data: flat sequence of 81 values, each 1 (X), -1 (O) or 0 (empty),
    indexed row-major as ``row * 9 + col``.
    """
    symbols = {1: "X", -1: "O", 0: "."}
    rows = []
    for row in range(9):
        cells = [symbols[int(board_data[row * 9 + col])] for col in range(9)]
        groups = [" ".join(cells[idx:idx + 3]) for idx in (0, 3, 6)]
        line = " | ".join(groups)
        rows.append(line)
        # Horizontal rule between sub-board bands; sized from the rendered row
        # (the previous hard-coded 23 dashes overhung the 21-character rows).
        if row in (2, 5):
            rows.append("-" * len(line))
    return "\n".join(rows)
|
||||
|
||||
|
||||
def top_policy_moves(policy, limit):
    """Return the *limit* highest-probability ``(action, prob)`` pairs, best first."""
    order = np.argsort(policy)[::-1]
    return [(int(idx), float(policy[idx])) for idx in order[:limit]]
|
||||
|
||||
|
||||
def parse_moves(text):
    """Parse a comma-separated move list (e.g. ``"3, 40,41"``) into ints.

    Empty or falsy input yields an empty list; blank segments are skipped.
    """
    if not text:
        return []
    tokens = (piece.strip() for piece in text.split(","))
    return [int(token) for token in tokens if token]
|
||||
|
||||
|
||||
def parse_action(text):
    """Parse a human move: a flat index in [0, 80] or ``"row col"`` pair in [0, 8].

    Commas are treated as whitespace, so ``"2,3"`` and ``"2 3"`` are equivalent.
    Raises ValueError for out-of-range values or a malformed token count.
    """
    tokens = text.strip().replace(",", " ").split()
    if len(tokens) == 2:
        row, col = int(tokens[0]), int(tokens[1])
        if not (0 <= row < 9 and 0 <= col < 9):
            raise ValueError("Row and column must be in [0, 8].")
        action = row * 9 + col
    elif len(tokens) == 1:
        action = int(tokens[0])
    else:
        raise ValueError("Enter either a flat move index or 'row col'.")
    if not (0 <= action < 81):
        raise ValueError("Move index must be in [0, 80].")
    return action
|
||||
|
||||
|
||||
def scalar_value(value):
    """Collapse a scalar-like number/array to a plain Python float (first element)."""
    flattened = np.asarray(value).reshape(-1)
    return float(flattened[0])
|
||||
|
||||
|
||||
def train_command(args):
    """CLI entry point: merge CLI overrides into DEFAULT_ARGS and run training.

    Optionally resumes model and optimizer state from ``--checkpoint`` before
    starting the self-play learning loop.
    """
    device = get_device(args.device)
    game = UltimateTicTacToe()
    model = build_model(game, device)

    # Config key -> argparse attribute; a few spellings differ between the two.
    key_to_attr = {
        "num_simulations": "num_simulations",
        "numIters": "num_iters",
        "numEps": "num_eps",
        "epochs": "epochs",
        "batch_size": "batch_size",
        "lr": "lr",
        "weight_decay": "weight_decay",
        "replay_buffer_size": "replay_buffer_size",
        "value_loss_weight": "value_loss_weight",
        "grad_clip_norm": "grad_clip_norm",
        "checkpoint_path": "checkpoint",
        "temperature_threshold": "temperature_threshold",
        "root_dirichlet_alpha": "root_dirichlet_alpha",
        "root_exploration_fraction": "root_exploration_fraction",
        "arena_compare_games": "arena_compare_games",
        "arena_accept_threshold": "arena_accept_threshold",
        "arena_compare_simulations": "arena_compare_simulations",
    }
    train_args = dict(DEFAULT_ARGS)
    train_args.update({key: getattr(args, attr) for key, attr in key_to_attr.items()})

    trainer = Trainer(game, model, train_args)
    if args.resume:
        # Resume loads optimizer state too, so training momentum carries over.
        load_checkpoint(model, args.checkpoint, device, optimizer=trainer.optimizer)
    trainer.learn()
|
||||
|
||||
|
||||
def eval_command(args):
    """CLI entry point: load a checkpoint and report its opinion on a position.

    Replays ``--moves`` from the start, prints the board, the model's value
    and top policy moves, and (with ``--with-mcts``) the greedy search move.
    """
    device = get_device(args.device)
    game = UltimateTicTacToe()
    model = build_model(game, device)
    load_checkpoint(model, args.checkpoint, device)

    state, player = apply_moves(game, parse_moves(args.moves))
    current_state = canonical_state(game, state, player)
    raw_policy, value = model.predict(game.encode_state(current_state))

    # Zero out illegal moves, then renormalise if anything survived the mask.
    mask = np.array(game.get_valid_moves(state), dtype=np.float32)
    policy = raw_policy * mask
    total = policy.sum()
    if total > 0:
        policy = policy / total

    print("Board:")
    print(format_board(state[0]))
    print()
    print(f"Side to move: {'X' if player == 1 else 'O'}")
    print(f"Active small board: {state[1]}")
    print(f"Model value: {scalar_value(value):.4f}")
    print("Top policy moves:")
    for action, prob in top_policy_moves(policy, args.top_k):
        print(f"  {action:2d} -> {prob:.4f}")

    if args.with_mcts:
        search_args = dict(DEFAULT_ARGS)
        search_args["num_simulations"] = args.num_simulations
        # Disable root noise so the analysis search is deterministic.
        search_args["root_dirichlet_alpha"] = None
        search_args["root_exploration_fraction"] = None
        root = MCTS(game, model, search_args).run(model, current_state, to_play=1)
        action = root.select_action(temperature=0)
        print(f"MCTS best move: {action}")
|
||||
|
||||
|
||||
def ai_action(game, model, state, player, num_simulations):
    """Pick the greedy MCTS move for *player* in *state* (no root noise)."""
    search_args = dict(DEFAULT_ARGS)
    search_args["num_simulations"] = num_simulations
    # No Dirichlet noise: match play is deterministic given the model.
    search_args["root_dirichlet_alpha"] = None
    search_args["root_exploration_fraction"] = None

    canonical = canonical_state(game, state, player)
    search_root = MCTS(game, model, search_args).run(model, canonical, to_play=1)
    return search_root.select_action(temperature=0)
|
||||
|
||||
|
||||
def random_action(game, state):
    """Uniformly sample one legal action index for *state*.

    Raises ValueError when the position has no legal action.
    """
    legal = [idx for idx, flag in enumerate(game.get_valid_moves(state)) if flag]
    if not legal:
        raise ValueError("No legal actions available.")
    return int(np.random.choice(legal))
|
||||
|
||||
|
||||
def load_player_model(game, checkpoint, device):
    """Build a model for *game* and populate it from *checkpoint* (must exist)."""
    player_model = build_model(game, device)
    load_checkpoint(player_model, checkpoint, device)
    return player_model
|
||||
|
||||
|
||||
def choose_action(game, player_kind, model, state, player, num_simulations):
    """Dispatch a move: uniform-random when *player_kind* is "random", else MCTS."""
    if player_kind != "random":
        return ai_action(game, model, state, player, num_simulations)
    return random_action(game, state)
|
||||
|
||||
|
||||
def play_match(game, x_kind, x_model, o_kind, o_model, num_simulations):
    """Play one complete game between the X and O agents.

    Returns 1 when X wins, -1 when O wins, 0 on a draw. The reward from
    ``get_reward_for_player`` is relative to the side to move, so it is
    mapped back to an absolute winner before returning.
    """
    board_state = game.get_init_board()
    to_move = 1

    while True:
        outcome = game.get_reward_for_player(board_state, to_move)
        if outcome is not None:
            if outcome == 0:
                return 0
            return to_move if outcome == 1 else -to_move

        kind, agent = (x_kind, x_model) if to_move == 1 else (o_kind, o_model)
        move = choose_action(game, kind, agent, board_state, to_move, num_simulations)
        board_state, to_move = game.get_next_state(board_state, to_move, move)
|
||||
|
||||
|
||||
def arena_command(args):
    """CLI entry point: pit the configured X and O agents and print the tally."""
    device = get_device(args.device)
    game = UltimateTicTacToe()

    # A model is only needed (and loaded) for checkpoint-backed players.
    x_model = None
    if args.x_player == "checkpoint":
        x_model = load_player_model(game, args.x_checkpoint, device)
    o_model = None
    if args.o_player == "checkpoint":
        o_model = load_player_model(game, args.o_checkpoint, device)

    tally = {1: 0, -1: 0, 0: 0}
    for _ in range(args.games):
        outcome = play_match(
            game,
            args.x_player,
            x_model,
            args.o_player,
            o_model,
            args.num_simulations,
        )
        tally[outcome] += 1

    print(f"Games: {args.games}")
    print(f"X ({args.x_player}) wins: {tally[1]}")
    print(f"O ({args.o_player}) wins: {tally[-1]}")
    print(f"Draws: {tally[0]}")
|
||||
|
||||
|
||||
def play_command(args):
    """CLI entry point: interactive human-vs-checkpoint game on the terminal.

    Prints the board every turn; the human enters moves as a flat index or
    'row col' (re-prompted on bad input), the AI side plays the greedy MCTS
    move. Returns once the game ends.
    """
    device = get_device(args.device)
    game = UltimateTicTacToe()
    model = build_model(game, device)
    load_checkpoint(model, args.checkpoint, device)

    state = game.get_init_board()
    player = 1  # X always moves first
    human_player = args.human_player  # 1 -> human plays X, -1 -> human plays O

    while True:
        print()
        print(format_board(state[0]))
        print(f"Turn: {'X' if player == 1 else 'O'}")
        print(f"Active small board: {state[1]}")

        # Terminal check; reward is from the current player's perspective,
        # so the absolute winner must be mapped back through `player`.
        reward = game.get_reward_for_player(state, player)
        if reward is not None:
            if reward == 0:
                print("Result: draw")
            else:
                winner = player if reward == 1 else -player
                print(f"Winner: {'X' if winner == 1 else 'O'}")
            return

        valid_moves = game.get_valid_moves(state)
        legal_actions = [index for index, allowed in enumerate(valid_moves) if allowed]
        print(f"Legal moves: {legal_actions}")

        if player == human_player:
            # Re-prompt until the input both parses and is a legal move;
            # parse_action and the verify_move check funnel into ValueError.
            while True:
                try:
                    action = parse_action(input("Your move (index or 'row col'): "))
                    next_state = game.get_next_state(state, player, action, verify_move=True)
                    if next_state is False:
                        raise ValueError(f"Illegal move: {action}")
                    state, player = next_state
                    break
                except ValueError as exc:
                    print(exc)
        else:
            action = ai_action(game, model, state, player, args.num_simulations)
            print(f"AI move: {action}")
            state, player = game.get_next_state(state, player, action)
|
||||
|
||||
|
||||
def build_parser():
    """Build the argparse CLI with train/eval/play/arena subcommands.

    Each subparser stores its handler via ``set_defaults(func=...)`` so
    ``main`` can dispatch without inspecting the command name. Defaults are
    drawn from DEFAULT_ARGS so the CLI and config stay in sync.
    """
    parser = argparse.ArgumentParser(description="Ultimate Tic-Tac-Toe Runner")
    subparsers = parser.add_subparsers(dest="command", required=True)

    # train: self-play training; flags mirror DEFAULT_ARGS keys.
    train_parser = subparsers.add_parser("train", help="Train the model with self-play")
    train_parser.add_argument("--device")
    train_parser.add_argument("--checkpoint", default=DEFAULT_ARGS["checkpoint_path"])
    train_parser.add_argument("--resume", action="store_true")
    train_parser.add_argument("--num-simulations", type=int, default=DEFAULT_ARGS["num_simulations"])
    train_parser.add_argument("--num-iters", type=int, default=DEFAULT_ARGS["numIters"])
    train_parser.add_argument("--num-eps", type=int, default=DEFAULT_ARGS["numEps"])
    train_parser.add_argument("--epochs", type=int, default=DEFAULT_ARGS["epochs"])
    train_parser.add_argument("--batch-size", type=int, default=DEFAULT_ARGS["batch_size"])
    train_parser.add_argument("--lr", type=float, default=DEFAULT_ARGS["lr"])
    train_parser.add_argument("--weight-decay", type=float, default=DEFAULT_ARGS["weight_decay"])
    train_parser.add_argument("--replay-buffer-size", type=int, default=DEFAULT_ARGS["replay_buffer_size"])
    train_parser.add_argument("--value-loss-weight", type=float, default=DEFAULT_ARGS["value_loss_weight"])
    train_parser.add_argument("--grad-clip-norm", type=float, default=DEFAULT_ARGS["grad_clip_norm"])
    train_parser.add_argument("--temperature-threshold", type=int, default=DEFAULT_ARGS["temperature_threshold"])
    train_parser.add_argument("--root-dirichlet-alpha", type=float, default=DEFAULT_ARGS["root_dirichlet_alpha"])
    train_parser.add_argument("--root-exploration-fraction", type=float, default=DEFAULT_ARGS["root_exploration_fraction"])
    train_parser.add_argument("--arena-compare-games", type=int, default=DEFAULT_ARGS["arena_compare_games"])
    train_parser.add_argument("--arena-accept-threshold", type=float, default=DEFAULT_ARGS["arena_accept_threshold"])
    train_parser.add_argument("--arena-compare-simulations", type=int, default=DEFAULT_ARGS["arena_compare_simulations"])
    train_parser.set_defaults(func=train_command)

    # eval: static inspection of a position from a checkpoint.
    eval_parser = subparsers.add_parser("eval", help="Inspect a checkpoint on a position")
    eval_parser.add_argument("--device")
    eval_parser.add_argument("--checkpoint", default=DEFAULT_ARGS["checkpoint_path"])
    eval_parser.add_argument("--moves", default="", help="Comma-separated move sequence")
    eval_parser.add_argument("--top-k", type=int, default=10)
    eval_parser.add_argument("--with-mcts", action="store_true")
    eval_parser.add_argument("--num-simulations", type=int, default=DEFAULT_ARGS["num_simulations"])
    eval_parser.set_defaults(func=eval_command)

    # play: interactive human-vs-AI session (1 = human plays X, -1 = O).
    play_parser = subparsers.add_parser("play", help="Play against the checkpoint")
    play_parser.add_argument("--device")
    play_parser.add_argument("--checkpoint", default=DEFAULT_ARGS["checkpoint_path"])
    play_parser.add_argument("--human-player", type=int, choices=[1, -1], default=1)
    play_parser.add_argument("--num-simulations", type=int, default=DEFAULT_ARGS["num_simulations"])
    play_parser.set_defaults(func=play_command)

    # arena: scripted agent-vs-agent matches; each side is either a
    # checkpoint-backed MCTS player or a uniform-random player.
    arena_parser = subparsers.add_parser("arena", help="Run repeated matches between agents")
    arena_parser.add_argument("--device")
    arena_parser.add_argument("--games", type=int, default=20)
    arena_parser.add_argument("--num-simulations", type=int, default=DEFAULT_ARGS["num_simulations"])
    arena_parser.add_argument("--x-player", choices=["checkpoint", "random"], default="checkpoint")
    arena_parser.add_argument("--o-player", choices=["checkpoint", "random"], default="random")
    arena_parser.add_argument("--x-checkpoint", default=DEFAULT_ARGS["checkpoint_path"])
    arena_parser.add_argument("--o-checkpoint", default=DEFAULT_ARGS["checkpoint_path"])
    arena_parser.set_defaults(func=arena_command)

    return parser
|
||||
|
||||
|
||||
def main():
    """Parse command-line arguments and dispatch to the selected subcommand."""
    cli = build_parser()
    parsed = cli.parse_args()
    parsed.func(parsed)


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user