First commit

This commit is contained in:
2026-03-23 21:19:29 +01:00
commit 29fc731e6c
7 changed files with 1173 additions and 0 deletions

4
.gitignore vendored Normal file
View File

@@ -0,0 +1,4 @@
__pycache__/
.ipynb_checkpoints/
*.pth

60
README.md Normal file
View File

@@ -0,0 +1,60 @@
# Ultimate Tic Tac Toe Deep Learning Bot
**Usage**
Run `python run.py --help` for help.
**Shared flags**
- --device: Torch device string. If omitted, it auto-picks "cuda" when available, otherwise "cpu".
- --checkpoint: Path to the model checkpoint file. Default is latest.pth. It is loaded for eval, play, and checkpoint-based arena,
and used as the save path for accepted training checkpoints.
**Training parameters**
- --resume: Loads model and optimizer state from --checkpoint before continuing training.
- --num-simulations default 100: MCTS rollouts per move during self-play. Higher is stronger/slower.
- --num-iters default 50: Number of outer training iterations. Each iteration generates new self-play games, trains, then arena-tests
the new model.
- --num-eps default 20: Self-play games per iteration.
- --epochs default 5: Passes over the current replay-buffer training set per iteration.
- --batch-size default 64: Mini-batch size for gradient updates.
- --lr default 5e-4: Adam learning rate.
- --weight-decay default 1e-4: Adam weight decay (L2-style regularization).
- --replay-buffer-size default 50000: Maximum number of training examples retained across iterations. Older examples are dropped.
- --value-loss-weight default 1.0: Multiplier on the value-head loss in total training loss. Total loss is policy_KL +
value_loss_weight * value_loss.
- --grad-clip-norm default 5.0: Global gradient norm clipping threshold before optimizer step.
- --temperature-threshold default 10: In self-play, moves before this step use stochastic sampling from MCTS visit counts; later
moves use greedy selection.
- --root-dirichlet-alpha default 0.3: Dirichlet noise alpha added to root priors during self-play MCTS to force exploration.
- --root-exploration-fraction default 0.25: How much of that root prior is replaced by Dirichlet noise.
- --arena-compare-games default 6: Number of head-to-head games between candidate and previous model after each iteration. If <= 0,
every candidate is accepted.
- --arena-accept-threshold default 0.55: Minimum average points needed in arena to keep the new model. Win = 1, draw = 0.5.
- --arena-compare-simulations default 8: MCTS simulations per move during those arena comparison games. Separate from self-play
--num-simulations.
**Evaluation parameters**
- --moves default "": Comma-separated move list to reach a position from the starting board, e.g. 0,10,4.
- --top-k default 10: How many highest-probability legal moves to print from the model policy.
- --with-mcts: Also run MCTS on that position and print the best move, instead of only raw network policy/value.
- --num-simulations default 100: Only matters with --with-mcts; controls MCTS search depth for that evaluation.
**Play parameters**
- --human-player default 1: Which side you control. 1 means X, -1 means O.
- --num-simulations default 100: MCTS simulations the AI uses for each move.
**Arena parameters**
- --games default 20: Number of matches to run.
- --num-simulations default 100: MCTS simulations per move for checkpoint-based players.
- --x-player / --o-player: Either checkpoint or random. Chooses the agent type for each side.
- --x-checkpoint / --o-checkpoint: Checkpoint path for that side when its player type is checkpoint. Ignored for random.
A few practical examples:
```bash
python run.py train --num-iters 100 --num-eps 50 --resume
python run.py eval --checkpoint latest.pth --moves 0,10,4 --with-mcts --num-simulations 200
python run.py play --human-player -1 --num-simulations 300
python run.py arena --games 50 --x-player checkpoint --o-player random
```

196
game.py Normal file
View File

@@ -0,0 +1,196 @@
import numpy as np
# All 8 winning line patterns (rows, columns, diagonals) as index triples
# within a single 3x3 board.
WIN_PATTERNS = [
    (0, 1, 2), (3, 4, 5), (6, 7, 8),   # rows
    (0, 3, 6), (1, 4, 7), (2, 5, 8),   # columns
    (0, 4, 8), (2, 4, 6),              # diagonals
]


class UltimateTicTacToe:
    """Ultimate Tic-Tac-Toe: a 3x3 grid of small 3x3 boards.

    A game state is a tuple ``(board_data, active_board)`` where
    ``board_data`` is a flat length-81 numpy array over {1, -1, 0}
    (X stone, O stone, empty) and ``active_board`` is the index (0-8)
    of the small board the next move must be played in, or ``None``
    when the player may move in any open small board.

    (Fixed: the previous docstring described ConnectX, a different game.)
    """

    def __init__(self):
        self.cells = 81          # total playable cells (9 small boards x 9)
        self.board_width = 9     # the full board is 9x9 cells
        self.state_planes = 9    # feature planes produced by encode_state

    def get_init_board(self):
        """Return the empty starting state; any small board is playable."""
        return (np.zeros((self.cells,), dtype=int), None)

    def get_board_size(self):
        """Shape of the encoded state tensor: (planes, height, width)."""
        return (self.state_planes, self.board_width, self.board_width)

    def get_action_size(self):
        """Number of distinct actions (one per cell)."""
        return self.cells

    def get_next_state(self, board, player, action, verify_move=False):
        """Apply ``action`` for ``player``; return ((data, next_board), -player).

        When ``verify_move`` is set, return False instead of mutating if the
        move is illegal. The next active small board is determined by the
        cell's position within the small board just played; if that target
        board is finished, the opponent may play anywhere (None).
        """
        if verify_move and self.get_valid_moves(board)[action] == 0:
            return False
        new_board_data = np.copy(board[0])
        new_board_data[action] = player
        # Row/col of the cell inside its small board selects the next board.
        next_board = ((action // 9) % 3) * 3 + (action % 3)
        if self.is_board_full(new_board_data, next_board):
            next_board = None
        # The returned player is the one to move next (perspective flips).
        return ((new_board_data, next_board), -player)

    def is_board_full(self, board_data, next_board):
        """True when the small board is decided: won by either side or full."""
        return (
            self._is_small_board_win(board_data, next_board, 1)
            or self._is_small_board_win(board_data, next_board, -1)
            or self._is_board_full(board_data, next_board)
        )

    def _small_board_cells(self, inner_board_idx):
        """Flat indices of the 9 cells of small board ``inner_board_idx``."""
        base = (inner_board_idx // 3) * 27 + (inner_board_idx % 3) * 3
        return [base + 9 * r + c for r in range(3) for c in range(3)]

    def _is_board_full(self, board_data, next_board):
        """True when the small board has no empty cell left."""
        return all(board_data[a] != 0 for a in self._small_board_cells(next_board))

    def _is_playable_small_board(self, board_data, inner_board_idx):
        """A small board is playable while it is neither won nor full."""
        return not self.is_board_full(board_data, inner_board_idx)

    def has_legal_moves(self, board):
        """True if at least one legal move remains."""
        return 1 in self.get_valid_moves(board)

    def get_valid_moves(self, board):
        """Return a length-81 0/1 list marking legal actions for ``board``."""
        board_data, active_board = board
        valid_moves = [0] * self.get_action_size()
        # A finished target board frees the player to move anywhere.
        if active_board is not None and not self._is_playable_small_board(board_data, active_board):
            active_board = None
        if active_board is None:
            boards = [i for i in range(9) if self._is_playable_small_board(board_data, i)]
        else:
            boards = [active_board]
        for inner_board_idx in boards:
            for index in self._small_board_cells(inner_board_idx):
                if board_data[index] == 0:
                    valid_moves[index] = 1
        return valid_moves

    def _is_small_board_win(self, board_data, inner_board_idx, player):
        """True when ``player`` holds a full line inside the small board."""
        cells = self._small_board_cells(inner_board_idx)
        return any(
            board_data[cells[a]] == board_data[cells[b]] == board_data[cells[c]] == player
            for a, b, c in WIN_PATTERNS
        )

    def is_win(self, board, player):
        """True when ``player`` has won three small boards in a line."""
        board_data, _ = board
        won = [self._is_small_board_win(board_data, i, player) for i in range(9)]
        return any(won[a] and won[b] and won[c] for a, b, c in WIN_PATTERNS)

    def get_reward_for_player(self, board, player):
        """Return 1/-1 for a win/loss by ``player``, 0 for a draw, None if ongoing."""
        if self.is_win(board, player):
            return 1
        if self.is_win(board, -player):
            return -1
        if self.has_legal_moves(board):
            return None
        return 0

    def get_canonical_board_data(self, board_data, player):
        """Flip stone signs so the side to move is always +1."""
        return player * board_data

    def _small_board_mask(self, inner_board_idx):
        """9x9 float mask with 1s over the cells of one small board."""
        mask = np.zeros((self.board_width, self.board_width), dtype=np.float32)
        for index in self._small_board_cells(inner_board_idx):
            mask[index // self.board_width, index % self.board_width] = 1.0
        return mask

    def encode_state(self, board):
        """Encode ``board`` as a (9, 9, 9) float32 plane stack for the network."""
        board_data, active_board = board
        grid = board_data.reshape(self.board_width, self.board_width)
        legal_moves = np.array(self.get_valid_moves(board), dtype=np.float32).reshape(
            self.board_width, self.board_width
        )
        active_board_mask = np.zeros((self.board_width, self.board_width), dtype=np.float32)
        if active_board is not None and self._is_playable_small_board(board_data, active_board):
            active_board_mask = self._small_board_mask(active_board)
        current_won_boards = np.zeros((self.board_width, self.board_width), dtype=np.float32)
        opponent_won_boards = np.zeros((self.board_width, self.board_width), dtype=np.float32)
        playable_boards = np.zeros((self.board_width, self.board_width), dtype=np.float32)
        for inner_board_idx in range(9):
            board_mask = self._small_board_mask(inner_board_idx)
            if self._is_small_board_win(board_data, inner_board_idx, 1):
                current_won_boards += board_mask
            elif self._is_small_board_win(board_data, inner_board_idx, -1):
                opponent_won_boards += board_mask
            if self._is_playable_small_board(board_data, inner_board_idx):
                playable_boards += board_mask
        # Normalized stone count gives the net a sense of game progress.
        move_count = np.count_nonzero(board_data) / self.cells
        move_count_plane = np.full((self.board_width, self.board_width), move_count, dtype=np.float32)
        return np.stack(
            (
                (grid == 1).astype(np.float32),
                (grid == -1).astype(np.float32),
                (grid == 0).astype(np.float32),
                legal_moves,
                active_board_mask,
                current_won_boards,
                opponent_won_boards,
                playable_boards,
                move_count_plane,
            ),
            axis=0,
        )

190
mcts.py Normal file
View File

@@ -0,0 +1,190 @@
import torch
import math
import numpy as np
def ucb_score(parent, child):
    """UCB score of the edge from ``parent`` to ``child``.

    Combines the network prior (scaled by parent visits) with the child's
    mean value, negated because the child is evaluated from the opposing
    player's perspective. Unvisited children contribute no value term.
    """
    exploration = child.prior * math.sqrt(parent.visit_count) / (child.visit_count + 1)
    exploitation = -child.value() if child.visit_count > 0 else 0
    return exploitation + exploration
class Node:
    """One state node in the MCTS tree."""

    def __init__(self, prior, to_play):
        self.visit_count = 0     # simulations that passed through this node
        self.to_play = to_play   # player (1/-1) to move at this node
        self.prior = prior       # network prior probability of reaching us
        self.value_sum = 0       # accumulated backed-up values
        self.children = {}       # action -> child Node
        self.state = None        # game state, filled in on expansion

    def expanded(self):
        """A node counts as expanded once it has at least one child."""
        return bool(self.children)

    def value(self):
        """Mean backed-up value; 0 before any visit."""
        return self.value_sum / self.visit_count if self.visit_count else 0

    def select_action(self, temperature):
        """Sample an action from the visit-count distribution.

        Temperature 0 is greedy, infinity is uniform; otherwise the counts
        are sharpened/flattened by 1/temperature before sampling.
        """
        counts = np.array([child.visit_count for child in self.children.values()])
        actions = list(self.children.keys())
        if temperature == 0:
            return actions[np.argmax(counts)]
        if temperature == float("inf"):
            return np.random.choice(actions)
        weights = counts ** (1 / temperature)
        return np.random.choice(actions, p=weights / sum(weights))

    def select_child(self):
        """Return the (action, child) pair maximizing the UCB score."""
        best_score, best_action, best_child = -np.inf, -1, None
        for action, child in self.children.items():
            score = ucb_score(self, child)
            if score > best_score:
                best_score, best_action, best_child = score, action, child
        return best_action, best_child

    def expand(self, state, to_play, action_probs):
        """Create children for every action with non-zero prior probability."""
        self.to_play = to_play
        self.state = state
        for action, prob in enumerate(action_probs):
            if prob != 0:
                self.children[action] = Node(prior=prob, to_play=-self.to_play)

    def __repr__(self):
        """Compact debug representation for interactive inspection."""
        return "{} Prior: {:.2f} Count: {} Value: {}".format(
            self.state.__str__(), self.prior, self.visit_count, self.value()
        )
class MCTS:
    """Monte Carlo Tree Search driven by the network's policy/value heads."""

    def __init__(self, game, model, args):
        self.game = game
        self.model = model
        # dict of hyperparameters: num_simulations, root noise settings, ...
        self.args = args

    def _masked_policy(self, state, model):
        """Network policy restricted to legal moves, renormalized.

        Falls back to a uniform distribution over legal moves when the
        masked policy sums to zero. Returns (probs, value) as
        (numpy array, float).
        """
        encoded_state = self.game.encode_state(state)
        action_probs, value = model.predict(encoded_state)
        valid_moves = np.array(self.game.get_valid_moves(state), dtype=np.float32)
        action_probs = action_probs * valid_moves
        total_prob = np.sum(action_probs)
        total_valid = np.sum(valid_moves)
        if total_valid <= 0:
            # No legal moves at all (terminal-like position).
            return valid_moves, float(value)
        if total_prob <= 0:
            # Network put zero mass on every legal move: fall back to uniform.
            action_probs = valid_moves / total_valid
        else:
            action_probs /= total_prob
        return action_probs, float(value)

    def _add_exploration_noise(self, node):
        """Mix Dirichlet noise into root priors (self-play exploration).

        No-op when either noise parameter is None (arena/eval/play mode) or
        the node has no children.
        """
        alpha = self.args.get('root_dirichlet_alpha')
        fraction = self.args.get('root_exploration_fraction')
        if alpha is None or fraction is None or not node.children:
            return
        actions = list(node.children.keys())
        noise = np.random.dirichlet([alpha] * len(actions))
        for action, sample in zip(actions, noise):
            child = node.children[action]
            child.prior = child.prior * (1 - fraction) + sample * fraction

    def run(self, model, state, to_play):
        """Run ``num_simulations`` select/expand/backpropagate passes.

        ``state`` must already be canonical (side to move encoded as +1).
        Returns the root Node; callers read its children's visit counts.
        """
        model = model or self.model
        root = Node(0, to_play)
        root.state = state
        # Terminal position: return a root that just holds the final reward.
        reward = self.game.get_reward_for_player(state, player=1)
        if reward is not None:
            root.value_sum = float(reward)
            root.visit_count = 1
            return root
        # EXPAND root
        action_probs, value = self._masked_policy(state, model)
        root.expand(state, to_play, action_probs)
        if not root.children:
            root.value_sum = float(value)
            root.visit_count = 1
            return root
        self._add_exploration_noise(root)
        for _ in range(self.args['num_simulations']):
            node = root
            search_path = [node]
            # SELECT: walk down by UCB until we reach an unexpanded node.
            xp = False
            while node.expanded():
                action, node = node.select_child()
                # NOTE(review): defensive guard — should be `node is None`;
                # `value` here is whatever the last leaf evaluation left behind.
                if node == None:
                    parent = search_path[-1]
                    self.backpropagate(search_path, value, parent.to_play * -1)
                    xp = True
                    break
                search_path.append(node)
            if xp:
                continue
            parent = search_path[-2]
            state = parent.state
            # Now we're at a leaf node and we would like to expand.
            # Players always play from their own (canonical +1) perspective.
            next_state, _ = self.game.get_next_state(state, player=1, action=action)
            # Re-canonicalize so the opponent becomes +1 in the child state.
            next_state_data, next_state_inner = next_state
            next_state = (self.game.get_canonical_board_data(next_state_data, player=-1), next_state_inner)
            # The value of the new state from the perspective of the other player.
            value = self.game.get_reward_for_player(next_state, player=1)
            if value is None:
                # Game not over: EXPAND the leaf with the network policy.
                action_probs, value = self._masked_policy(next_state, model)
                node.expand(next_state, parent.to_play * -1, action_probs)
            self.backpropagate(search_path, float(value), parent.to_play * -1)
        return root

    def backpropagate(self, search_path, value, to_play):
        """Propagate a leaf evaluation up to the root, flipping sign per side."""
        for node in reversed(search_path):
            node.value_sum += value if node.to_play == to_play else -value
            node.visit_count += 1

81
model.py Normal file
View File

@@ -0,0 +1,81 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResidualBlock(nn.Module):
    """Two 3x3 conv+BN layers with an identity skip connection."""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        # conv-BN-ReLU, conv-BN, add the skip, then the final ReLU —
        # the standard ResNet basic-block layout.
        out = self.bn2(self.conv2(F.relu(self.bn1(self.conv1(x)))))
        return F.relu(out + x)
class UltimateTicTacToeModel(nn.Module):
    """AlphaZero-style conv tower with separate policy and value heads.

    Input is a (planes, 9, 9) encoded board; forward returns
    (softmax policy over 81 actions, tanh value in [-1, 1]).
    """

    def __init__(self, board_size, action_size, device, channels=64, num_blocks=6):
        super().__init__()
        self.action_size = action_size
        self.input_shape = board_size          # (planes, height, width)
        self.input_channels = board_size[0]
        self.board_height = board_size[1]
        self.board_width = board_size[2]
        self.device = torch.device(device)
        # Stem lifts the input planes to the tower's channel width.
        self.stem = nn.Sequential(
            nn.Conv2d(self.input_channels, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        )
        self.residual_tower = nn.Sequential(*(ResidualBlock(channels) for _ in range(num_blocks)))
        # Policy head: 1x1 conv then a linear layer over all cells.
        self.policy_head = nn.Sequential(
            nn.Conv2d(channels, 32, kernel_size=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )
        self.policy_fc = nn.Linear(32 * self.board_height * self.board_width, self.action_size)
        # Value head: 1x1 conv, hidden linear layer, scalar output.
        self.value_head = nn.Sequential(
            nn.Conv2d(channels, 32, kernel_size=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )
        self.value_fc1 = nn.Linear(32 * self.board_height * self.board_width, 128)
        self.value_fc2 = nn.Linear(128, 1)
        self.to(self.device)

    def forward(self, x):
        """Return (policy probabilities, value) for a batch of encoded boards."""
        x = x.view(-1, *self.input_shape)
        x = self.stem(x)
        x = self.residual_tower(x)
        policy = self.policy_head(x)
        policy = torch.flatten(policy, 1)
        policy = self.policy_fc(policy)
        value = self.value_head(x)
        value = torch.flatten(value, 1)
        value = F.relu(self.value_fc1(value))
        value = torch.tanh(self.value_fc2(value))
        # Softmax is applied here, so downstream losses must expect
        # probabilities rather than raw logits.
        return F.softmax(policy, dim=1), value

    def predict(self, board):
        """Single-position inference; returns (numpy policy[81], float value).

        Side effect: switches the module to eval mode.
        """
        board = torch.as_tensor(board, dtype=torch.float32, device=self.device)
        board = board.view(1, *self.input_shape)
        self.eval()
        with torch.no_grad():
            pi, v = self.forward(board)
        return pi.detach().cpu().numpy()[0], float(v.item())

384
run.py Normal file
View File

@@ -0,0 +1,384 @@
import argparse
from pathlib import Path
import numpy as np
import torch
from game import UltimateTicTacToe
from mcts import MCTS
from model import UltimateTicTacToeModel
from trainer import Trainer
# Default hyperparameters shared by all CLI subcommands; the flags defined in
# build_parser() override these per-run. Keys mirror what Trainer/MCTS read.
DEFAULT_ARGS = {
    "num_simulations": 100,             # MCTS rollouts per move
    "numIters": 50,                     # outer training iterations
    "numEps": 20,                       # self-play games per iteration
    "epochs": 5,                        # training passes per iteration
    "batch_size": 64,
    "lr": 5e-4,
    "weight_decay": 1e-4,
    "replay_buffer_size": 50000,        # max retained training examples
    "value_loss_weight": 1.0,
    "grad_clip_norm": 5.0,
    "checkpoint_path": "latest.pth",
    "temperature_threshold": 10,        # moves before greedy self-play selection
    "root_dirichlet_alpha": 0.3,        # root exploration noise (self-play only)
    "root_exploration_fraction": 0.25,
    "arena_compare_games": 6,           # candidate-vs-previous games; <=0 accepts all
    "arena_accept_threshold": 0.55,     # min average points to keep a candidate
    "arena_compare_simulations": 8,     # MCTS budget during arena comparison
}
def get_device(device_arg):
    """Resolve the torch device string.

    Honors an explicit (truthy) choice; otherwise prefers CUDA when
    available and falls back to CPU.
    """
    if device_arg:
        return device_arg
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
def build_model(game, device):
    """Construct the policy/value network sized for ``game`` on ``device``."""
    board_size = game.get_board_size()
    action_size = game.get_action_size()
    return UltimateTicTacToeModel(board_size, action_size, device)
def load_checkpoint(model, checkpoint_path, device, optimizer=None, required=True):
    """Load model (and optionally optimizer) weights from ``checkpoint_path``.

    Returns True on success, False when the file is absent and ``required``
    is False; raises FileNotFoundError when absent and required.
    Leaves the model in eval mode after loading.
    """
    path = Path(checkpoint_path)
    if not path.exists():
        if not required:
            return False
        raise FileNotFoundError(f"Checkpoint not found: {path}")
    state = torch.load(path, map_location=device)
    model.load_state_dict(state["state_dict"])
    if optimizer is not None and "optimizer_state_dict" in state:
        optimizer.load_state_dict(state["optimizer_state_dict"])
    model.eval()
    return True
def canonical_state(game, state, player):
    """Return ``state`` from ``player``'s perspective (stone signs flipped)."""
    board_data, active_board = state
    canonical_data = game.get_canonical_board_data(board_data, player)
    return (canonical_data, active_board)
def apply_moves(game, moves):
    """Replay ``moves`` from the initial position.

    Returns (state, player to move); raises ValueError on the first
    illegal move in the sequence.
    """
    state, player = game.get_init_board(), 1
    for action in moves:
        outcome = game.get_next_state(state, player, action, verify_move=True)
        if outcome is False:
            raise ValueError(f"Illegal move in sequence: {action}")
        state, player = outcome
    return state, player
def format_board(board_data):
    """Render the flat 81-cell board as ASCII with 3x3 group separators."""
    symbols = {1: "X", -1: "O", 0: "."}
    lines = []
    for row in range(9):
        cells = [symbols[int(board_data[row * 9 + col])] for col in range(9)]
        triples = (" ".join(cells[start:start + 3]) for start in (0, 3, 6))
        lines.append(" | ".join(triples))
        # Horizontal rule after each band of three small-board rows.
        if row in (2, 5):
            lines.append("-" * 23)
    return "\n".join(lines)
def top_policy_moves(policy, limit):
    """Return the ``limit`` highest-probability (action, prob) pairs, descending."""
    order = np.argsort(policy)[::-1]
    return [(int(action), float(policy[action])) for action in order[:limit]]
def parse_moves(text):
    """Parse a comma-separated move list like "0,10,4" into ints; '' -> []."""
    if not text:
        return []
    tokens = (part.strip() for part in text.split(","))
    return [int(token) for token in tokens if token]
def parse_action(text):
    """Parse a move given as a flat index ("40") or as "row col" / "row,col".

    Raises ValueError for malformed input or out-of-range coordinates.
    """
    tokens = text.strip().replace(",", " ").split()
    if len(tokens) == 1:
        action = int(tokens[0])
    elif len(tokens) == 2:
        row, col = int(tokens[0]), int(tokens[1])
        if not (0 <= row < 9 and 0 <= col < 9):
            raise ValueError("Row and column must be in [0, 8].")
        action = row * 9 + col
    else:
        raise ValueError("Enter either a flat move index or 'row col'.")
    if not (0 <= action < 81):
        raise ValueError("Move index must be in [0, 80].")
    return action
def scalar_value(value):
    """Coerce a scalar, array, or 1-element tensor-like value to a float."""
    flat = np.asarray(value).reshape(-1)
    return float(flat[0])
def train_command(args):
    """CLI handler for ``train``: build the model and run the self-play loop."""
    device = get_device(args.device)
    game = UltimateTicTacToe()
    model = build_model(game, device)
    # Start from the defaults, then overlay every CLI flag.
    train_args = dict(DEFAULT_ARGS)
    train_args.update(
        {
            "num_simulations": args.num_simulations,
            "numIters": args.num_iters,
            "numEps": args.num_eps,
            "epochs": args.epochs,
            "batch_size": args.batch_size,
            "lr": args.lr,
            "weight_decay": args.weight_decay,
            "replay_buffer_size": args.replay_buffer_size,
            "value_loss_weight": args.value_loss_weight,
            "grad_clip_norm": args.grad_clip_norm,
            "checkpoint_path": args.checkpoint,
            "temperature_threshold": args.temperature_threshold,
            "root_dirichlet_alpha": args.root_dirichlet_alpha,
            "root_exploration_fraction": args.root_exploration_fraction,
            "arena_compare_games": args.arena_compare_games,
            "arena_accept_threshold": args.arena_accept_threshold,
            "arena_compare_simulations": args.arena_compare_simulations,
        }
    )
    trainer = Trainer(game, model, train_args)
    if args.resume:
        # Restore both model weights and optimizer state before continuing.
        load_checkpoint(model, args.checkpoint, device, optimizer=trainer.optimizer)
    trainer.learn()
def eval_command(args):
    """CLI handler for ``eval``: report value/policy (optionally MCTS) for a position."""
    device = get_device(args.device)
    game = UltimateTicTacToe()
    model = build_model(game, device)
    load_checkpoint(model, args.checkpoint, device)
    moves = parse_moves(args.moves)
    state, player = apply_moves(game, moves)
    # The network always sees the position from the side-to-move perspective.
    current_state = canonical_state(game, state, player)
    encoded = game.encode_state(current_state)
    policy, value = model.predict(encoded)
    # Mask out illegal moves and renormalize the policy for display.
    legal_mask = np.array(game.get_valid_moves(state), dtype=np.float32)
    policy = policy * legal_mask
    if policy.sum() > 0:
        policy = policy / policy.sum()
    print("Board:")
    print(format_board(state[0]))
    print()
    print(f"Side to move: {'X' if player == 1 else 'O'}")
    print(f"Active small board: {state[1]}")
    print(f"Model value: {scalar_value(value):.4f}")
    print("Top policy moves:")
    for action, prob in top_policy_moves(policy, args.top_k):
        print(f" {action:2d} -> {prob:.4f}")
    if args.with_mcts:
        # Noise-free search for a deterministic best-move readout.
        mcts_args = dict(DEFAULT_ARGS)
        mcts_args.update(
            {
                "num_simulations": args.num_simulations,
                "root_dirichlet_alpha": None,
                "root_exploration_fraction": None,
            }
        )
        root = MCTS(game, model, mcts_args).run(model, current_state, to_play=1)
        action = root.select_action(temperature=0)
        print(f"MCTS best move: {action}")
def ai_action(game, model, state, player, num_simulations):
    """Pick the AI's move: greedy, noise-free MCTS from ``player``'s view."""
    search_args = dict(DEFAULT_ARGS)
    search_args["num_simulations"] = num_simulations
    # Exploration noise is disabled for deterministic, competitive play.
    search_args["root_dirichlet_alpha"] = None
    search_args["root_exploration_fraction"] = None
    view = canonical_state(game, state, player)
    root = MCTS(game, model, search_args).run(model, view, to_play=1)
    return root.select_action(temperature=0)
def random_action(game, state):
    """Uniformly pick one legal action; raise ValueError when none exist."""
    legal_actions = [action for action, allowed in enumerate(game.get_valid_moves(state)) if allowed]
    if not legal_actions:
        raise ValueError("No legal actions available.")
    return int(np.random.choice(legal_actions))
def load_player_model(game, checkpoint, device):
    """Build a fresh model for ``game`` and load ``checkpoint`` into it."""
    player_model = build_model(game, device)
    load_checkpoint(player_model, checkpoint, device)
    return player_model
def choose_action(game, player_kind, model, state, player, num_simulations):
    """Dispatch to a random or MCTS-backed move depending on ``player_kind``."""
    if player_kind != "random":
        return ai_action(game, model, state, player, num_simulations)
    return random_action(game, state)
def play_match(game, x_kind, x_model, o_kind, o_model, num_simulations):
    """Play one full game; return 1 (X wins), -1 (O wins), or 0 (draw).

    ``x_kind``/``o_kind`` are "checkpoint" or "random"; the matching model
    is only consulted for checkpoint players.
    """
    state = game.get_init_board()
    player = 1
    while True:
        reward = game.get_reward_for_player(state, player)
        if reward is not None:
            if reward == 0:
                return 0
            # reward is relative to ``player``; map back to absolute X/O.
            return player if reward == 1 else -player
        if player == 1:
            action = choose_action(game, x_kind, x_model, state, player, num_simulations)
        else:
            action = choose_action(game, o_kind, o_model, state, player, num_simulations)
        state, player = game.get_next_state(state, player, action)
def arena_command(args):
    """CLI handler for ``arena``: pit two configured agents and print totals."""
    device = get_device(args.device)
    game = UltimateTicTacToe()
    x_model = None
    o_model = None
    # Models are only loaded for checkpoint-backed players.
    if args.x_player == "checkpoint":
        x_model = load_player_model(game, args.x_checkpoint, device)
    if args.o_player == "checkpoint":
        o_model = load_player_model(game, args.o_checkpoint, device)
    results = {1: 0, -1: 0, 0: 0}   # X wins / O wins / draws
    for _ in range(args.games):
        winner = play_match(
            game,
            args.x_player,
            x_model,
            args.o_player,
            o_model,
            args.num_simulations,
        )
        results[winner] += 1
    print(f"Games: {args.games}")
    print(f"X ({args.x_player}) wins: {results[1]}")
    print(f"O ({args.o_player}) wins: {results[-1]}")
    print(f"Draws: {results[0]}")
def play_command(args):
    """CLI handler for ``play``: interactive human-vs-AI game loop on stdin."""
    device = get_device(args.device)
    game = UltimateTicTacToe()
    model = build_model(game, device)
    load_checkpoint(model, args.checkpoint, device)
    state = game.get_init_board()
    player = 1
    human_player = args.human_player   # 1 plays X, -1 plays O
    while True:
        print()
        print(format_board(state[0]))
        print(f"Turn: {'X' if player == 1 else 'O'}")
        print(f"Active small board: {state[1]}")
        reward = game.get_reward_for_player(state, player)
        if reward is not None:
            if reward == 0:
                print("Result: draw")
            else:
                # Map the side-relative reward back to an absolute winner.
                winner = player if reward == 1 else -player
                print(f"Winner: {'X' if winner == 1 else 'O'}")
            return
        valid_moves = game.get_valid_moves(state)
        legal_actions = [index for index, allowed in enumerate(valid_moves) if allowed]
        print(f"Legal moves: {legal_actions}")
        if player == human_player:
            # Re-prompt until the human enters a parseable, legal move.
            while True:
                try:
                    action = parse_action(input("Your move (index or 'row col'): "))
                    next_state = game.get_next_state(state, player, action, verify_move=True)
                    if next_state is False:
                        raise ValueError(f"Illegal move: {action}")
                    state, player = next_state
                    break
                except ValueError as exc:
                    print(exc)
        else:
            action = ai_action(game, model, state, player, args.num_simulations)
            print(f"AI move: {action}")
            state, player = game.get_next_state(state, player, action)
def build_parser():
    """Assemble the argparse CLI: train / eval / play / arena subcommands."""
    parser = argparse.ArgumentParser(description="Ultimate Tic-Tac-Toe Runner")
    subparsers = parser.add_subparsers(dest="command", required=True)
    # -- train --------------------------------------------------------------
    train_parser = subparsers.add_parser("train", help="Train the model with self-play")
    train_parser.add_argument("--device")
    train_parser.add_argument("--checkpoint", default=DEFAULT_ARGS["checkpoint_path"])
    train_parser.add_argument("--resume", action="store_true")
    train_parser.add_argument("--num-simulations", type=int, default=DEFAULT_ARGS["num_simulations"])
    train_parser.add_argument("--num-iters", type=int, default=DEFAULT_ARGS["numIters"])
    train_parser.add_argument("--num-eps", type=int, default=DEFAULT_ARGS["numEps"])
    train_parser.add_argument("--epochs", type=int, default=DEFAULT_ARGS["epochs"])
    train_parser.add_argument("--batch-size", type=int, default=DEFAULT_ARGS["batch_size"])
    train_parser.add_argument("--lr", type=float, default=DEFAULT_ARGS["lr"])
    train_parser.add_argument("--weight-decay", type=float, default=DEFAULT_ARGS["weight_decay"])
    train_parser.add_argument("--replay-buffer-size", type=int, default=DEFAULT_ARGS["replay_buffer_size"])
    train_parser.add_argument("--value-loss-weight", type=float, default=DEFAULT_ARGS["value_loss_weight"])
    train_parser.add_argument("--grad-clip-norm", type=float, default=DEFAULT_ARGS["grad_clip_norm"])
    train_parser.add_argument("--temperature-threshold", type=int, default=DEFAULT_ARGS["temperature_threshold"])
    train_parser.add_argument("--root-dirichlet-alpha", type=float, default=DEFAULT_ARGS["root_dirichlet_alpha"])
    train_parser.add_argument("--root-exploration-fraction", type=float, default=DEFAULT_ARGS["root_exploration_fraction"])
    train_parser.add_argument("--arena-compare-games", type=int, default=DEFAULT_ARGS["arena_compare_games"])
    train_parser.add_argument("--arena-accept-threshold", type=float, default=DEFAULT_ARGS["arena_accept_threshold"])
    train_parser.add_argument("--arena-compare-simulations", type=int, default=DEFAULT_ARGS["arena_compare_simulations"])
    train_parser.set_defaults(func=train_command)
    # -- eval ---------------------------------------------------------------
    eval_parser = subparsers.add_parser("eval", help="Inspect a checkpoint on a position")
    eval_parser.add_argument("--device")
    eval_parser.add_argument("--checkpoint", default=DEFAULT_ARGS["checkpoint_path"])
    eval_parser.add_argument("--moves", default="", help="Comma-separated move sequence")
    eval_parser.add_argument("--top-k", type=int, default=10)
    eval_parser.add_argument("--with-mcts", action="store_true")
    eval_parser.add_argument("--num-simulations", type=int, default=DEFAULT_ARGS["num_simulations"])
    eval_parser.set_defaults(func=eval_command)
    # -- play ---------------------------------------------------------------
    play_parser = subparsers.add_parser("play", help="Play against the checkpoint")
    play_parser.add_argument("--device")
    play_parser.add_argument("--checkpoint", default=DEFAULT_ARGS["checkpoint_path"])
    play_parser.add_argument("--human-player", type=int, choices=[1, -1], default=1)
    play_parser.add_argument("--num-simulations", type=int, default=DEFAULT_ARGS["num_simulations"])
    play_parser.set_defaults(func=play_command)
    # -- arena --------------------------------------------------------------
    arena_parser = subparsers.add_parser("arena", help="Run repeated matches between agents")
    arena_parser.add_argument("--device")
    arena_parser.add_argument("--games", type=int, default=20)
    arena_parser.add_argument("--num-simulations", type=int, default=DEFAULT_ARGS["num_simulations"])
    arena_parser.add_argument("--x-player", choices=["checkpoint", "random"], default="checkpoint")
    arena_parser.add_argument("--o-player", choices=["checkpoint", "random"], default="random")
    arena_parser.add_argument("--x-checkpoint", default=DEFAULT_ARGS["checkpoint_path"])
    arena_parser.add_argument("--o-checkpoint", default=DEFAULT_ARGS["checkpoint_path"])
    arena_parser.set_defaults(func=arena_command)
    return parser
def main():
    """CLI entry point: parse arguments and dispatch to the chosen subcommand."""
    cli_args = build_parser().parse_args()
    cli_args.func(cli_args)


if __name__ == "__main__":
    main()

258
trainer.py Normal file
View File

@@ -0,0 +1,258 @@
import os
import math
import copy
import numpy as np
from collections import deque
from random import shuffle
from progressbar import ProgressBar, Percentage, Bar, ETA, AdaptiveETA, FormatLabel
import torch
import torch.optim as optim
from mcts import MCTS
# Shared progress-bar state: latest loss components plus epoch counters,
# updated elsewhere during training and rendered by LossWidget.
loss = {'ploss': float('inf'), 'pkl': float('inf'), 'vloss': float('inf'), 'current': 0, 'max': 0}


class LossWidget:
    """progressbar2 widget that renders the shared ``loss`` dict."""

    def __call__(self, progress, data=None):
        stats = loss
        return (
            f" {stats['current']}/{stats['max']} "
            f"P.CE: {stats['ploss']:.4f}, P.KL: {stats['pkl']:.4f}, V.Loss: {stats['vloss']:.4f}"
        )
class Trainer:
    def __init__(self, game, model, args):
        """Set up self-play training: MCTS, Adam optimizer, replay buffer.

        ``args`` is a dict of hyperparameters (see run.py DEFAULT_ARGS).
        """
        self.game = game
        self.model = model
        self.args = args
        self.mcts = MCTS(self.game, self.model, self.args)
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=self.args.get('lr', 5e-4),
            weight_decay=self.args.get('weight_decay', 1e-4),
        )
        # Bounded FIFO of training examples; old games fall off the back.
        self.replay_buffer = deque(maxlen=self.args.get('replay_buffer_size', 50000))
    def _ai_action(self, model, state, player):
        """Greedy (temperature 0) MCTS move for arena games, without root noise."""
        board_data, active_board = state
        # Always search from the side-to-move's canonical (+1) perspective.
        canonical_state = (self.game.get_canonical_board_data(board_data, player), active_board)
        mcts_args = dict(self.args)
        mcts_args.update(
            {
                # Arena games use their own (usually smaller) simulation budget.
                'num_simulations': self.args.get('arena_compare_simulations', self.args['num_simulations']),
                # Disable exploration noise for deterministic evaluation.
                'root_dirichlet_alpha': None,
                'root_exploration_fraction': None,
            }
        )
        root = MCTS(self.game, model, mcts_args).run(model, canonical_state, to_play=1)
        return root.select_action(temperature=0)
    def _play_arena_game(self, x_model, o_model):
        """Play one game between two models; return 1 (X wins), -1 (O wins), 0 (draw)."""
        state = self.game.get_init_board()
        current_player = 1
        while True:
            reward = self.game.get_reward_for_player(state, current_player)
            if reward is not None:
                if reward == 0:
                    return 0
                # reward is relative to current_player; map back to X/O.
                return current_player if reward == 1 else -current_player
            model = x_model if current_player == 1 else o_model
            action = self._ai_action(model, state, current_player)
            state, current_player = self.game.get_next_state(state, current_player, action)
    def evaluate_candidate(self, candidate_model, reference_model):
        """Arena-compare candidate vs reference; return (accepted, score).

        The candidate plays X for the first half of the games and O for the
        second. Win = 1 point, draw = 0.5. Stops early as soon as the
        accept/reject decision is mathematically settled either way.
        """
        games = self.args.get('arena_compare_games', 0)
        if games <= 0:
            # Arena disabled: accept every candidate with a neutral score.
            return True, 0.5
        candidate_points = 0.0
        required_points = self.args.get('arena_accept_threshold', 0.55) * games
        candidate_first_games = (games + 1) // 2
        candidate_second_games = games // 2
        games_played = 0
        for _ in range(candidate_first_games):
            winner = self._play_arena_game(candidate_model, reference_model)
            if winner == 1:
                candidate_points += 1.0
            elif winner == 0:
                candidate_points += 0.5
            games_played += 1
            remaining_games = games - games_played
            # Early accept/reject once the remaining games cannot change it.
            if candidate_points >= required_points:
                return True, candidate_points / games
            if candidate_points + remaining_games < required_points:
                return False, candidate_points / games
        for _ in range(candidate_second_games):
            # Colors swapped: candidate now plays O, so -1 is a candidate win.
            winner = self._play_arena_game(reference_model, candidate_model)
            if winner == -1:
                candidate_points += 1.0
            elif winner == 0:
                candidate_points += 0.5
            games_played += 1
            remaining_games = games - games_played
            if candidate_points >= required_points:
                return True, candidate_points / games
            if candidate_points + remaining_games < required_points:
                return False, candidate_points / games
        score = candidate_points / games
        return score >= self.args.get('arena_accept_threshold', 0.55), score
def exceute_episode(self):
train_examples = []
current_player = 1
state = self.game.get_init_board()
episode_step = 0
while True:
board_data, state_inner_board = state
cannonical_board_data = self.game.get_canonical_board_data(board_data, current_player)
canonical_board = (cannonical_board_data, state_inner_board)
self.mcts = MCTS(self.game, self.model, self.args)
root = self.mcts.run(self.model, canonical_board, to_play=1)
action_probs = np.zeros(self.game.get_action_size(), dtype=np.float32)
for k, v in root.children.items():
action_probs[k] = v.visit_count
action_probs = action_probs / np.sum(action_probs)
encoded_state = self.game.encode_state(canonical_board)
train_examples.append((encoded_state, current_player, action_probs))
temperature_threshold = self.args.get('temperature_threshold', 10)
temperature = 1 if episode_step < temperature_threshold else 0
action = root.select_action(temperature=temperature)
state, current_player = self.game.get_next_state(state, current_player, action)
reward = self.game.get_reward_for_player(state, current_player)
episode_step += 1
if reward is not None:
ret = []
for hist_state, hist_current_player, hist_action_probs in train_examples:
# [Board, currentPlayer, actionProbabilities, Reward]
ret.append((hist_state, hist_action_probs, reward * ((-1) ** (hist_current_player != current_player))))
return ret
def learn(self):
widgets = [Percentage(), Bar(), AdaptiveETA(), LossWidget()]
pbar = ProgressBar(max_value=self.args['numIters'], widgets=widgets)
pbar.update(0)
for i in range(1, self.args['numIters'] + 1):
# print("{}/{}".format(i, self.args['numIters']))
train_examples = []
for eps in range(self.args['numEps']):
iteration_train_examples = self.exceute_episode()
train_examples.extend(iteration_train_examples)
self.replay_buffer.extend(train_examples)
training_examples = list(self.replay_buffer)
shuffle(training_examples)
shuffle(train_examples)
reference_model = copy.deepcopy(self.model)
reference_model.eval()
reference_state_dict = copy.deepcopy(self.model.state_dict())
reference_optimizer_state = copy.deepcopy(self.optimizer.state_dict())
results = self.train(training_examples)
accepted, arena_score = self.evaluate_candidate(self.model, reference_model)
if accepted:
filename = self.args['checkpoint_path']
self.save_checkpoint(folder=".", filename=filename)
else:
self.model.load_state_dict(reference_state_dict)
self.optimizer.load_state_dict(reference_optimizer_state)
# print((float(results[0]), float(results[1])))
loss['ploss'] = float(results[0])
loss['pkl'] = float(results[1])
loss['vloss'] = float(results[2])
loss['current'] = i
loss['max'] = self.args['numIters']
# loss_widget.update(float(results[0]), float(results[1]))
pbar.update(i)
pbar.finish()
def train(self, examples):
pi_losses = []
pi_kls = []
v_losses = []
device = self.model.device
for epoch in range(self.args['epochs']):
self.model.train()
shuffled_indices = np.random.permutation(len(examples))
for batch_start in range(0, len(examples), self.args['batch_size']):
sample_ids = shuffled_indices[batch_start:batch_start + self.args['batch_size']]
boards, pis, vs = list(zip(*[examples[i] for i in sample_ids]))
boards = torch.FloatTensor(np.array(boards).astype(np.float32))
target_pis = torch.FloatTensor(np.array(pis))
target_vs = torch.FloatTensor(np.array(vs).astype(np.float32))
# predict
boards = boards.contiguous().to(device)
target_pis = target_pis.contiguous().to(device)
target_vs = target_vs.contiguous().to(device)
# compute output
out_pi, out_v = self.model(boards)
l_pi = self.loss_pi(target_pis, out_pi)
l_pi_kl = self.loss_pi_kl(target_pis, out_pi)
l_v = self.loss_v(target_vs, out_v)
total_loss = l_pi_kl + self.args.get('value_loss_weight', 1.0) * l_v
pi_losses.append(float(l_pi.detach()))
pi_kls.append(float(l_pi_kl.detach()))
v_losses.append(float(l_v.detach()))
self.optimizer.zero_grad()
total_loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.get('grad_clip_norm', 5.0))
self.optimizer.step()
# print()
# print("Policy Loss", np.mean(pi_losses))
# print("Value Loss", np.mean(v_losses))
return (np.mean(pi_losses), np.mean(pi_kls), np.mean(v_losses))
# print("Examples:")
# print(out_pi[0].detach())
# print(target_pis[0])
def loss_pi(self, targets, outputs):
loss = -(targets * torch.log(outputs.clamp_min(1e-8))).sum(dim=1)
return loss.mean()
def loss_pi_kl(self, targets, outputs):
target_log = torch.log(targets.clamp_min(1e-8))
output_log = torch.log(outputs.clamp_min(1e-8))
loss = (targets * (target_log - output_log)).sum(dim=1)
return loss.mean()
def loss_v(self, targets, outputs):
loss = torch.sum((targets-outputs.view(-1))**2)/targets.size()[0]
return loss
def save_checkpoint(self, folder, filename):
if not os.path.exists(folder):
os.mkdir(folder)
filepath = os.path.join(folder, filename)
torch.save({
'state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'args': self.args,
}, filepath)