First commit
This commit is contained in:
4
.gitignore
vendored
Normal file
4
.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
__pycache__/
|
||||||
|
.ipynb_checkpoints/
|
||||||
|
*.pth
|
||||||
|
|
||||||
60
README.md
Normal file
60
README.md
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
# Ultimate Tic Tac Toe Deep Learning Bot
|
||||||
|
|
||||||
|
**Usage**
|
||||||
|
Run `python run.py --help` for help.
|
||||||
|
|
||||||
|
**Shared flags**
|
||||||
|
- --device: Torch device string. If omitted, it auto-picks "cuda" when available, otherwise "cpu".
|
||||||
|
- --checkpoint: Path to the model checkpoint file. Default is latest.pth. It is loaded for eval, play, and checkpoint-based arena,
|
||||||
|
and used as the save path for accepted training checkpoints.
|
||||||
|
|
||||||
|
**Training parameters**
|
||||||
|
- --resume: Loads model and optimizer state from --checkpoint before continuing training.
|
||||||
|
- --num-simulations default 100: MCTS rollouts per move during self-play. Higher is stronger/slower.
|
||||||
|
- --num-iters default 50: Number of outer training iterations. Each iteration generates new self-play games, trains, then arena-tests
|
||||||
|
the new model.
|
||||||
|
- --num-eps default 20: Self-play games per iteration.
|
||||||
|
- --epochs default 5: Passes over the current replay-buffer training set per iteration.
|
||||||
|
- --batch-size default 64: Mini-batch size for gradient updates.
|
||||||
|
- --lr default 5e-4: Adam learning rate.
|
||||||
|
- --weight-decay default 1e-4: Adam weight decay (L2-style regularization).
|
||||||
|
- --replay-buffer-size default 50000: Maximum number of training examples retained across iterations. Older examples are dropped.
|
||||||
|
- --value-loss-weight default 1.0: Multiplier on the value-head loss in total training loss. Total loss is policy_KL +
|
||||||
|
value_loss_weight * value_loss.
|
||||||
|
- --grad-clip-norm default 5.0: Global gradient norm clipping threshold before optimizer step.
|
||||||
|
- --temperature-threshold default 10: In self-play, moves before this step use stochastic sampling from MCTS visit counts; later
|
||||||
|
moves use greedy selection.
|
||||||
|
- --root-dirichlet-alpha default 0.3: Dirichlet noise alpha added to root priors during self-play MCTS to force exploration.
|
||||||
|
- --root-exploration-fraction default 0.25: How much of that root prior is replaced by Dirichlet noise.
|
||||||
|
- --arena-compare-games default 6: Number of head-to-head games between candidate and previous model after each iteration. If <= 0,
|
||||||
|
every candidate is accepted.
|
||||||
|
- --arena-accept-threshold default 0.55: Minimum average points needed in arena to keep the new model. Win = 1, draw = 0.5.
|
||||||
|
- --arena-compare-simulations default 8: MCTS simulations per move during those arena comparison games. Separate from self-play
|
||||||
|
--num-simulations.
|
||||||
|
|
||||||
|
**Evaluation parameters**
|
||||||
|
|
||||||
|
- --moves default "": Comma-separated move list to reach a position from the starting board, e.g. 0,10,4.
|
||||||
|
- --top-k default 10: How many highest-probability legal moves to print from the model policy.
|
||||||
|
- --with-mcts: Also run MCTS on that position and print the best move, instead of only raw network policy/value.
|
||||||
|
- --num-simulations default 100: Only matters with --with-mcts; controls MCTS search depth for that evaluation.
|
||||||
|
|
||||||
|
**Play parameters**
|
||||||
|
|
||||||
|
- --human-player default 1: Which side you control. 1 means X, -1 means O.
|
||||||
|
- --num-simulations default 100: MCTS simulations the AI uses for each move.
|
||||||
|
|
||||||
|
**Arena parameters**
|
||||||
|
- --games default 20: Number of matches to run.
|
||||||
|
- --num-simulations default 100: MCTS simulations per move for checkpoint-based players.
|
||||||
|
- --x-player / --o-player: Either checkpoint or random. Chooses the agent type for each side.
|
||||||
|
- --x-checkpoint / --o-checkpoint: Checkpoint path for that side when its player type is checkpoint. Ignored for random.
|
||||||
|
|
||||||
|
A few practical examples:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python run.py train --num-iters 100 --num-eps 50 --resume
|
||||||
|
python run.py eval --checkpoint latest.pth --moves 0,10,4 --with-mcts --num-simulations 200
|
||||||
|
python run.py play --human-player -1 --num-simulations 300
|
||||||
|
python run.py arena --games 50 --x-player checkpoint --o-player random
|
||||||
|
```
|
||||||
196
game.py
Normal file
196
game.py
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Winning triples for a 3x3 grid (rows, columns, diagonals); used both for
# small boards and for the meta-board of won small boards.
WIN_PATTERNS = [
    (0, 1, 2),
    (3, 4, 5),
    (6, 7, 8),
    (0, 3, 6),
    (1, 4, 7),
    (2, 5, 8),
    (0, 4, 8),
    (2, 4, 6),
]


class UltimateTicTacToe:
    """Ultimate tic-tac-toe on a flat 81-cell board.

    A game state is the tuple ``(board_data, active_board)`` where
    ``board_data`` is a length-81 array over {1, -1, 0} (X, O, empty) laid
    out as a 9x9 grid, and ``active_board`` is the index (0-8) of the small
    board the next move is restricted to, or ``None`` when the player may
    move in any open small board.

    NOTE(review): the original docstring described a "ConnectX" game with
    1 row / 4 columns / winNumber 2 — a copy-paste leftover that did not
    match this class; it has been replaced.
    """

    def __init__(self):
        self.cells = 81          # total playable cells (9 small boards x 9 cells)
        self.board_width = 9     # the flat board is laid out as a 9x9 grid
        self.state_planes = 9    # feature planes produced by encode_state

    def get_init_board(self):
        """Return the empty starting position (all cells 0, no restriction)."""
        b = np.zeros((self.cells,), dtype=int)
        return (b, None)

    def get_board_size(self):
        """Shape of the encoded state: (planes, height, width)."""
        return (self.state_planes, self.board_width, self.board_width)

    def get_action_size(self):
        """One action per cell: 81."""
        return self.cells

    def get_next_state(self, board, player, action, verify_move=False):
        """Apply *action* for *player*; return ((new_board_data, next_board), -player).

        With ``verify_move=True`` an illegal action returns ``False`` instead
        of being applied.
        """
        if verify_move:
            if self.get_valid_moves(board)[action] == 0:
                return False
        new_board_data = np.copy(board[0])
        new_board_data[action] = player

        # The cell's position inside its small board determines which small
        # board the opponent must play in next.  (``action % 3`` equals
        # ``(action % 9) % 3`` because 9 is a multiple of 3.)
        next_board = ((action // 9) % 3) * 3 + (action % 3)
        if self.is_board_full(new_board_data, next_board):
            next_board = None  # target board is closed -> opponent plays anywhere

        # Return the new game, but change the perspective to the other player.
        return ((new_board_data, next_board), -player)

    def is_board_full(self, board_data, next_board):
        """True when small board *next_board* is closed: won by either side or full."""
        return (
            self._is_small_board_win(board_data, next_board, 1)
            or self._is_small_board_win(board_data, next_board, -1)
            or self._is_board_full(board_data, next_board)
        )

    def _small_board_cells(self, inner_board_idx):
        """Flat indices of the 9 cells making up small board *inner_board_idx*."""
        row_block = inner_board_idx // 3
        col_block = inner_board_idx % 3

        base = row_block * 27 + col_block * 3

        return [
            base, base + 1, base + 2,
            base + 9, base + 10, base + 11,
            base + 18, base + 19, base + 20,
        ]

    def _is_board_full(self, board_data, next_board):
        """True when every cell of small board *next_board* is occupied."""
        return all(board_data[a] != 0 for a in self._small_board_cells(next_board))

    def _is_playable_small_board(self, board_data, inner_board_idx):
        """A small board is playable while it is neither won nor full."""
        return not self.is_board_full(board_data, inner_board_idx)

    def has_legal_moves(self, board):
        """True when at least one legal move remains in *board*."""
        return any(m == 1 for m in self.get_valid_moves(board))

    def get_valid_moves(self, board):
        """Return a length-81 0/1 mask of legal actions for *board*."""
        # All moves are invalid by default.
        board_data, active_board = board
        valid_moves = [0] * self.get_action_size()

        # A restriction to a closed board is equivalent to no restriction.
        if active_board is not None and not self._is_playable_small_board(board_data, active_board):
            active_board = None

        if active_board is None:
            targets = [
                inner_board_idx
                for inner_board_idx in range(9)
                if self._is_playable_small_board(board_data, inner_board_idx)
            ]
        else:
            targets = [active_board]

        for inner_board_idx in targets:
            for index in self._small_board_cells(inner_board_idx):
                if board_data[index] == 0:
                    valid_moves[index] = 1

        return valid_moves

    def _is_small_board_win(self, board_data, inner_board_idx, player):
        """True when *player* has three in a row inside the small board."""
        cells = self._small_board_cells(inner_board_idx)
        return any(
            board_data[cells[a]] == board_data[cells[b]] == board_data[cells[c]] == player
            for a, b, c in WIN_PATTERNS
        )

    def is_win(self, board, player):
        """True when *player* owns three small boards forming a winning line."""
        board_data, _ = board
        won = [self._is_small_board_win(board_data, i, player) for i in range(9)]

        # Check if any winning combination on the meta-board is complete.
        return any(won[a] and won[b] and won[c] for a, b, c in WIN_PATTERNS)

    def get_reward_for_player(self, board, player):
        """Return None if not ended, 1 if *player* won, -1 if lost, 0 for a draw."""
        if self.is_win(board, player):
            return 1
        if self.is_win(board, -player):
            return -1
        if self.has_legal_moves(board):
            return None
        return 0

    def get_canonical_board_data(self, board_data, player):
        """Flip stone signs so the side to move always appears as +1."""
        return player * board_data

    def _small_board_mask(self, inner_board_idx):
        """9x9 float32 mask with 1.0 on the cells of the given small board."""
        mask = np.zeros((self.board_width, self.board_width), dtype=np.float32)
        for index in self._small_board_cells(inner_board_idx):
            row = index // self.board_width
            col = index % self.board_width
            mask[row, col] = 1.0
        return mask

    def encode_state(self, board):
        """Encode *board* as a (9, 9, 9) float32 feature stack for the network."""
        board_data, active_board = board
        board_grid = board_data.reshape(self.board_width, self.board_width)

        current_stones = (board_grid == 1).astype(np.float32)
        opponent_stones = (board_grid == -1).astype(np.float32)
        empty_cells = (board_grid == 0).astype(np.float32)
        legal_moves = np.array(self.get_valid_moves(board), dtype=np.float32).reshape(
            self.board_width, self.board_width
        )

        # Mask of the small board the player is restricted to (all zeros when
        # the restriction is absent or points at a closed board).
        active_board_mask = np.zeros((self.board_width, self.board_width), dtype=np.float32)
        if active_board is not None and self._is_playable_small_board(board_data, active_board):
            active_board_mask = self._small_board_mask(active_board)

        current_won_boards = np.zeros((self.board_width, self.board_width), dtype=np.float32)
        opponent_won_boards = np.zeros((self.board_width, self.board_width), dtype=np.float32)
        playable_boards = np.zeros((self.board_width, self.board_width), dtype=np.float32)

        for inner_board_idx in range(9):
            board_mask = self._small_board_mask(inner_board_idx)
            if self._is_small_board_win(board_data, inner_board_idx, 1):
                current_won_boards += board_mask
            elif self._is_small_board_win(board_data, inner_board_idx, -1):
                opponent_won_boards += board_mask

            if self._is_playable_small_board(board_data, inner_board_idx):
                playable_boards += board_mask

        # Game-progress plane: fraction of occupied cells, broadcast to 9x9.
        move_count = np.count_nonzero(board_data) / self.cells
        move_count_plane = np.full((self.board_width, self.board_width), move_count, dtype=np.float32)

        return np.stack(
            (
                current_stones,
                opponent_stones,
                empty_cells,
                legal_moves,
                active_board_mask,
                current_won_boards,
                opponent_won_boards,
                playable_boards,
                move_count_plane,
            ),
            axis=0,
        )
|
||||||
190
mcts.py
Normal file
190
mcts.py
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
import torch
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def ucb_score(parent, child):
    """
    The score for an action that would transition between the parent and child.
    """
    # Exploration term: the child's prior, scaled up when the parent is well
    # visited and the child is not.
    exploration = child.prior * math.sqrt(parent.visit_count) / (child.visit_count + 1)

    # Exploitation term: the child's value is stored from the opposing
    # player's perspective, so it is negated; unvisited children contribute 0.
    exploitation = -child.value() if child.visit_count > 0 else 0

    return exploitation + exploration
|
||||||
|
|
||||||
|
|
||||||
|
class Node:
    """A single game state in the MCTS search tree."""

    def __init__(self, prior, to_play):
        self.visit_count = 0    # times this node was traversed during search
        self.to_play = to_play  # player (1 / -1) to move in this state
        self.prior = prior      # policy prior assigned by the network
        self.value_sum = 0      # accumulated backed-up value
        self.children = {}      # action index -> child Node
        self.state = None       # game state, filled in on expansion

    def expanded(self):
        """A node counts as expanded once it has any children."""
        return bool(self.children)

    def value(self):
        """Mean backed-up value; 0 for an unvisited node."""
        return self.value_sum / self.visit_count if self.visit_count else 0

    def select_action(self, temperature):
        """
        Select action according to the visit count distribution and the temperature.
        """
        actions = list(self.children.keys())
        visit_counts = np.array([child.visit_count for child in self.children.values()])

        if temperature == 0:
            # Greedy: most-visited action.
            return actions[np.argmax(visit_counts)]
        if temperature == float("inf"):
            # Fully random among the children.
            return np.random.choice(actions)

        # See paper appendix Data Generation.
        weights = visit_counts ** (1 / temperature)
        return np.random.choice(actions, p=weights / sum(weights))

    def select_child(self):
        """
        Select the child with the highest UCB score.
        """
        best = (-np.inf, -1, None)  # (score, action, child)
        for action, child in self.children.items():
            score = ucb_score(self, child)
            if score > best[0]:
                best = (score, action, child)
        return best[1], best[2]

    def expand(self, state, to_play, action_probs):
        """
        We expand a node and keep track of the prior policy probability given by neural network
        """
        self.to_play = to_play
        self.state = state
        for action, prob in enumerate(action_probs):
            # Zero-probability actions (illegal or masked) get no child.
            if prob != 0:
                self.children[action] = Node(prior=prob, to_play=self.to_play * -1)

    def __repr__(self):
        """
        Debugger pretty print node info
        """
        prior = "{0:.2f}".format(self.prior)
        return "{} Prior: {} Count: {} Value: {}".format(self.state.__str__(), prior, self.visit_count, self.value())
|
||||||
|
|
||||||
|
|
||||||
|
class MCTS:
    """AlphaZero-style Monte Carlo tree search driven by a policy/value model.

    *args* is a dict; this class reads ``num_simulations`` and, via ``.get``,
    ``root_dirichlet_alpha`` / ``root_exploration_fraction``.
    """

    def __init__(self, game, model, args):
        self.game = game
        self.model = model
        self.args = args

    def _masked_policy(self, state, model):
        """Network policy for *state*, masked to legal moves and renormalized.

        Returns ``(action_probs, value)``.  If the network puts all its mass
        on illegal moves, the legal moves are weighted uniformly instead.
        """
        encoded_state = self.game.encode_state(state)
        action_probs, value = model.predict(encoded_state)
        valid_moves = np.array(self.game.get_valid_moves(state), dtype=np.float32)
        action_probs = action_probs * valid_moves
        total_prob = np.sum(action_probs)
        total_valid = np.sum(valid_moves)
        if total_valid <= 0:
            # Terminal position: there is no legal-move mass to distribute.
            return valid_moves, float(value)
        if total_prob <= 0:
            # All policy mass was on illegal moves -> uniform over legal ones.
            action_probs = valid_moves / total_valid
        else:
            action_probs /= total_prob
        return action_probs, float(value)

    def _add_exploration_noise(self, node):
        """Mix Dirichlet noise into the root priors (self-play exploration).

        No-op when either noise parameter is absent/None or the node has no
        children (so evaluation/play can disable noise by passing None).
        """
        alpha = self.args.get('root_dirichlet_alpha')
        fraction = self.args.get('root_exploration_fraction')
        if alpha is None or fraction is None or not node.children:
            return

        actions = list(node.children.keys())
        noise = np.random.dirichlet([alpha] * len(actions))
        for action, sample in zip(actions, noise):
            child = node.children[action]
            child.prior = child.prior * (1 - fraction) + sample * fraction

    def run(self, model, state, to_play):
        """Run ``args['num_simulations']`` simulations from *state*; return the root node."""
        model = model or self.model

        root = Node(0, to_play)
        root.state = state

        # Already-terminal root: just record the outcome.
        reward = self.game.get_reward_for_player(state, player=1)
        if reward is not None:
            root.value_sum = float(reward)
            root.visit_count = 1
            return root

        # EXPAND root
        action_probs, value = self._masked_policy(state, model)
        root.expand(state, to_play, action_probs)
        if not root.children:
            root.value_sum = float(value)
            root.visit_count = 1
            return root
        self._add_exploration_noise(root)

        for _ in range(self.args['num_simulations']):
            node = root
            search_path = [node]

            # SELECT down to a leaf.
            dead_end = False
            while node.expanded():
                action, node = node.select_child()
                # BUGFIX: was `node == None`; identity comparison (`is None`)
                # is the correct idiom and avoids __eq__ surprises.
                if node is None:
                    parent = search_path[-1]
                    self.backpropagate(search_path, value, parent.to_play * -1)
                    dead_end = True
                    break
                search_path.append(node)
            if dead_end:
                continue

            parent = search_path[-2]
            state = parent.state
            # Now we're at a leaf node and we would like to expand.
            # Players always play from their own perspective.
            next_state, _ = self.game.get_next_state(state, player=1, action=action)
            # Get the board from the perspective of the other player.
            next_state_data, next_state_inner = next_state
            next_state = (self.game.get_canonical_board_data(next_state_data, player=-1), next_state_inner)

            # The value of the new state from the perspective of the other player.
            value = self.game.get_reward_for_player(next_state, player=1)
            if value is None:
                # If the game has not ended: EXPAND with the masked policy.
                action_probs, value = self._masked_policy(next_state, model)
                node.expand(next_state, parent.to_play * -1, action_probs)

            self.backpropagate(search_path, float(value), parent.to_play * -1)

        return root

    def backpropagate(self, search_path, value, to_play):
        """
        At the end of a simulation, we propagate the evaluation all the way up the tree
        to the root.
        """
        for node in reversed(search_path):
            # Each node stores value from its own player's perspective.
            node.value_sum += value if node.to_play == to_play else -value
            node.visit_count += 1
|
||||||
81
model.py
Normal file
81
model.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlock(nn.Module):
    """Two 3x3 conv+batch-norm layers with an additive skip connection."""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        shortcut = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Add the skip connection before the final activation.
        return F.relu(out + shortcut)


class UltimateTicTacToeModel(nn.Module):
    """Convolutional trunk with separate policy and value heads (AlphaZero-style)."""

    def __init__(self, board_size, action_size, device, channels=64, num_blocks=6):
        super().__init__()

        self.action_size = action_size
        self.input_shape = board_size          # (planes, height, width)
        self.input_channels = board_size[0]
        self.board_height = board_size[1]
        self.board_width = board_size[2]
        self.device = torch.device(device)

        # Shared trunk: stem conv followed by a tower of residual blocks.
        self.stem = nn.Sequential(
            nn.Conv2d(self.input_channels, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        )
        self.residual_tower = nn.Sequential(*(ResidualBlock(channels) for _ in range(num_blocks)))

        # Policy head: 1x1 conv -> flatten -> logits over actions.
        self.policy_head = nn.Sequential(
            nn.Conv2d(channels, 32, kernel_size=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )
        self.policy_fc = nn.Linear(32 * self.board_height * self.board_width, self.action_size)

        # Value head: 1x1 conv -> flatten -> scalar squashed to [-1, 1].
        self.value_head = nn.Sequential(
            nn.Conv2d(channels, 32, kernel_size=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )
        self.value_fc1 = nn.Linear(32 * self.board_height * self.board_width, 128)
        self.value_fc2 = nn.Linear(128, 1)

        self.to(self.device)

    def forward(self, x):
        """Return (softmax policy over actions, tanh value) for a batch."""
        x = x.view(-1, *self.input_shape)
        x = self.residual_tower(self.stem(x))

        policy = self.policy_fc(torch.flatten(self.policy_head(x), 1))

        value = torch.flatten(self.value_head(x), 1)
        value = torch.tanh(self.value_fc2(F.relu(self.value_fc1(value))))

        return F.softmax(policy, dim=1), value

    def predict(self, board):
        """Single-position inference: array in, (policy ndarray, float value) out."""
        board = torch.as_tensor(board, dtype=torch.float32, device=self.device)
        board = board.view(1, *self.input_shape)
        self.eval()  # batch-norm must use running stats for a batch of one
        with torch.no_grad():
            pi, v = self.forward(board)

        return pi.detach().cpu().numpy()[0], float(v.item())
|
||||||
384
run.py
Normal file
384
run.py
Normal file
@@ -0,0 +1,384 @@
|
|||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from game import UltimateTicTacToe
|
||||||
|
from mcts import MCTS
|
||||||
|
from model import UltimateTicTacToeModel
|
||||||
|
from trainer import Trainer
|
||||||
|
|
||||||
|
|
||||||
|
# Shared hyper-parameter defaults; the CLI flags override these per command.
DEFAULT_ARGS = {
    "num_simulations": 100,            # MCTS rollouts per move
    "numIters": 50,                    # outer training iterations
    "numEps": 20,                      # self-play games per iteration
    "epochs": 5,                       # training passes per iteration
    "batch_size": 64,
    "lr": 5e-4,
    "weight_decay": 1e-4,
    "replay_buffer_size": 50000,       # max retained training examples
    "value_loss_weight": 1.0,
    "grad_clip_norm": 5.0,
    "checkpoint_path": "latest.pth",
    "temperature_threshold": 10,       # self-play moves before greedy selection
    "root_dirichlet_alpha": 0.3,
    "root_exploration_fraction": 0.25,
    "arena_compare_games": 6,
    "arena_accept_threshold": 0.55,
    "arena_compare_simulations": 8,
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_device(device_arg):
    """Return the explicit device string if given, else prefer CUDA when available."""
    return device_arg or ("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
|
||||||
|
def build_model(game, device):
    """Construct the network sized to *game*'s board and action space."""
    board_size = game.get_board_size()
    action_size = game.get_action_size()
    return UltimateTicTacToeModel(board_size, action_size, device)
|
||||||
|
|
||||||
|
|
||||||
|
def load_checkpoint(model, checkpoint_path, device, optimizer=None, required=True):
    """Load model (and optionally optimizer) weights from *checkpoint_path*.

    Returns True on success.  A missing file raises FileNotFoundError when
    *required*, and otherwise returns False.  The model is left in eval mode.
    """
    checkpoint = Path(checkpoint_path)
    if not checkpoint.exists():
        if not required:
            return False
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint}")

    payload = torch.load(checkpoint, map_location=device)
    model.load_state_dict(payload["state_dict"])
    # Optimizer state is optional in the saved file (e.g. eval-only checkpoints).
    if optimizer is not None and "optimizer_state_dict" in payload:
        optimizer.load_state_dict(payload["optimizer_state_dict"])
    model.eval()
    return True
|
||||||
|
|
||||||
|
|
||||||
|
def canonical_state(game, state, player):
    """Return *state* with the board data flipped to the side-to-move perspective."""
    board_data, active_board = state
    canonical_data = game.get_canonical_board_data(board_data, player)
    return (canonical_data, active_board)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_moves(game, moves):
    """Replay *moves* from the initial position; return (state, player to move).

    Raises ValueError on the first illegal move in the sequence.
    """
    state = game.get_init_board()
    player = 1
    for action in moves:
        result = game.get_next_state(state, player, action, verify_move=True)
        if result is False:
            raise ValueError(f"Illegal move in sequence: {action}")
        state, player = result
    return state, player
|
||||||
|
|
||||||
|
|
||||||
|
def format_board(board_data):
    """Render the 81-cell board as a human-readable 9x9 grid with 3x3 dividers."""
    symbols = {1: "X", -1: "O", 0: "."}
    rendered = []
    for row in range(9):
        cells = [symbols[int(board_data[row * 9 + col])] for col in range(9)]
        triplets = (" ".join(cells[start:start + 3]) for start in (0, 3, 6))
        rendered.append(" | ".join(triplets))
        # Horizontal divider after every third row.
        if row in (2, 5):
            rendered.append("-" * 23)
    return "\n".join(rendered)
|
||||||
|
|
||||||
|
|
||||||
|
def top_policy_moves(policy, limit):
    """Return up to *limit* (action, probability) pairs, highest probability first."""
    best = np.argsort(policy)[::-1][:limit]
    return [(int(a), float(policy[a])) for a in best]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_moves(text):
    """Parse a comma-separated move list like "0,10,4" into ints; ''/None -> []."""
    if not text:
        return []
    tokens = (part.strip() for part in text.split(","))
    return [int(token) for token in tokens if token]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_action(text):
    """Parse user input as either a flat index ("40") or "row col" ("4 4").

    Commas and whitespace are both accepted as separators.  Raises
    ValueError for malformed or out-of-range input.
    """
    tokens = text.strip().replace(",", " ").split()

    if len(tokens) == 1:
        action = int(tokens[0])
    elif len(tokens) == 2:
        row, col = (int(token) for token in tokens)
        if not (0 <= row < 9 and 0 <= col < 9):
            raise ValueError("Row and column must be in [0, 8].")
        action = row * 9 + col
    else:
        raise ValueError("Enter either a flat move index or 'row col'.")

    if not (0 <= action < 81):
        raise ValueError("Move index must be in [0, 80].")
    return action
|
||||||
|
|
||||||
|
|
||||||
|
def scalar_value(value):
    """Coerce a scalar, sequence, or ndarray to its first element as a float."""
    flat = np.asarray(value).reshape(-1)
    return float(flat[0])
|
||||||
|
|
||||||
|
|
||||||
|
def train_command(args):
    """CLI entry: run the self-play / train / arena loop configured by *args*."""
    device = get_device(args.device)
    game = UltimateTicTacToe()
    model = build_model(game, device)

    # Start from the shared defaults and apply every CLI override.
    train_args = dict(DEFAULT_ARGS)
    train_args.update(
        {
            "num_simulations": args.num_simulations,
            "numIters": args.num_iters,
            "numEps": args.num_eps,
            "epochs": args.epochs,
            "batch_size": args.batch_size,
            "lr": args.lr,
            "weight_decay": args.weight_decay,
            "replay_buffer_size": args.replay_buffer_size,
            "value_loss_weight": args.value_loss_weight,
            "grad_clip_norm": args.grad_clip_norm,
            "checkpoint_path": args.checkpoint,
            "temperature_threshold": args.temperature_threshold,
            "root_dirichlet_alpha": args.root_dirichlet_alpha,
            "root_exploration_fraction": args.root_exploration_fraction,
            "arena_compare_games": args.arena_compare_games,
            "arena_accept_threshold": args.arena_accept_threshold,
            "arena_compare_simulations": args.arena_compare_simulations,
        }
    )

    trainer = Trainer(game, model, train_args)
    if args.resume:
        # Continue from the saved model and optimizer state.
        load_checkpoint(model, args.checkpoint, device, optimizer=trainer.optimizer)
    trainer.learn()
|
||||||
|
|
||||||
|
|
||||||
|
def eval_command(args):
    """CLI entry: print the model's policy/value for the position reached by
    --moves, and optionally the MCTS-selected best move."""
    device = get_device(args.device)
    game = UltimateTicTacToe()
    model = build_model(game, device)
    load_checkpoint(model, args.checkpoint, device)

    moves = parse_moves(args.moves)
    state, player = apply_moves(game, moves)
    current_state = canonical_state(game, state, player)
    encoded = game.encode_state(current_state)
    policy, value = model.predict(encoded)

    # Mask out illegal moves and renormalize so the printout is a distribution.
    legal_mask = np.array(game.get_valid_moves(state), dtype=np.float32)
    policy = policy * legal_mask
    if policy.sum() > 0:
        policy = policy / policy.sum()

    print("Board:")
    print(format_board(state[0]))
    print()
    print(f"Side to move: {'X' if player == 1 else 'O'}")
    print(f"Active small board: {state[1]}")
    print(f"Model value: {scalar_value(value):.4f}")
    print("Top policy moves:")
    for action, prob in top_policy_moves(policy, args.top_k):
        print(f"  {action:2d} -> {prob:.4f}")

    if args.with_mcts:
        # Deterministic search: disable root noise for evaluation.
        mcts_args = dict(DEFAULT_ARGS)
        mcts_args.update(
            {
                "num_simulations": args.num_simulations,
                "root_dirichlet_alpha": None,
                "root_exploration_fraction": None,
            }
        )
        root = MCTS(game, model, mcts_args).run(model, current_state, to_play=1)
        action = root.select_action(temperature=0)
        print(f"MCTS best move: {action}")
|
||||||
|
|
||||||
|
|
||||||
|
def ai_action(game, model, state, player, num_simulations):
    """Pick a move for *player* via deterministic (noise-free) MCTS."""
    current_state = canonical_state(game, state, player)
    mcts_args = dict(DEFAULT_ARGS)
    mcts_args.update(
        {
            "num_simulations": num_simulations,
            # No exploration noise at play time.
            "root_dirichlet_alpha": None,
            "root_exploration_fraction": None,
        }
    )
    root = MCTS(game, model, mcts_args).run(model, current_state, to_play=1)
    return root.select_action(temperature=0)
|
||||||
|
|
||||||
|
|
||||||
|
def random_action(game, state):
    """Uniformly pick one of the legal actions; raise ValueError if none exist."""
    legal = [idx for idx, ok in enumerate(game.get_valid_moves(state)) if ok]
    if not legal:
        raise ValueError("No legal actions available.")
    return int(np.random.choice(legal))
|
||||||
|
|
||||||
|
|
||||||
|
def load_player_model(game, checkpoint, device):
    """Build a model for *game* and load *checkpoint* into it (file must exist)."""
    model = build_model(game, device)
    load_checkpoint(model, checkpoint, device)
    return model
|
||||||
|
|
||||||
|
|
||||||
|
def choose_action(game, player_kind, model, state, player, num_simulations):
    """Dispatch to a random or checkpoint (MCTS) agent based on *player_kind*."""
    if player_kind == "random":
        return random_action(game, state)
    return ai_action(game, model, state, player, num_simulations)
|
||||||
|
|
||||||
|
|
||||||
|
def play_match(game, x_kind, x_model, o_kind, o_model, num_simulations):
    """Play one full game; return 1 if X wins, -1 if O wins, 0 for a draw."""
    state = game.get_init_board()
    player = 1

    while True:
        reward = game.get_reward_for_player(state, player)
        if reward is not None:
            if reward == 0:
                return 0
            # Translate the side-relative reward into the absolute winner.
            return player if reward == 1 else -player

        if player == 1:
            action = choose_action(game, x_kind, x_model, state, player, num_simulations)
        else:
            action = choose_action(game, o_kind, o_model, state, player, num_simulations)
        state, player = game.get_next_state(state, player, action)
|
||||||
|
|
||||||
|
|
||||||
|
def arena_command(args):
    """Run `args.games` head-to-head matches between the configured agents and print a tally."""
    device = get_device(args.device)
    game = UltimateTicTacToe()

    # Models are only loaded for checkpoint-backed players; random agents need none.
    x_model = load_player_model(game, args.x_checkpoint, device) if args.x_player == "checkpoint" else None
    o_model = load_player_model(game, args.o_checkpoint, device) if args.o_player == "checkpoint" else None

    tally = {1: 0, -1: 0, 0: 0}
    for _ in range(args.games):
        outcome = play_match(
            game,
            args.x_player,
            x_model,
            args.o_player,
            o_model,
            args.num_simulations,
        )
        tally[outcome] += 1

    print(f"Games: {args.games}")
    print(f"X ({args.x_player}) wins: {tally[1]}")
    print(f"O ({args.o_player}) wins: {tally[-1]}")
    print(f"Draws: {tally[0]}")
|
|
||||||
|
|
||||||
|
def play_command(args):
    """Interactive console game: the human plays against the checkpoint model.

    The human's side is given by args.human_player (1 = X, -1 = O).
    Loops until the game ends, printing the board before every move.
    """
    device = get_device(args.device)
    game = UltimateTicTacToe()
    model = build_model(game, device)
    load_checkpoint(model, args.checkpoint, device)

    state = game.get_init_board()
    player = 1  # X always moves first
    human_player = args.human_player

    while True:
        print()
        # state is (board_data, active_small_board) — board render uses state[0].
        print(format_board(state[0]))
        print(f"Turn: {'X' if player == 1 else 'O'}")
        print(f"Active small board: {state[1]}")

        # Terminal check before asking for a move; reward is relative to `player`.
        reward = game.get_reward_for_player(state, player)
        if reward is not None:
            if reward == 0:
                print("Result: draw")
            else:
                winner = player if reward == 1 else -player
                print(f"Winner: {'X' if winner == 1 else 'O'}")
            return

        valid_moves = game.get_valid_moves(state)
        legal_actions = [index for index, allowed in enumerate(valid_moves) if allowed]
        print(f"Legal moves: {legal_actions}")

        if player == human_player:
            # Retry loop: bad parses and illegal moves both surface as ValueError
            # and simply re-prompt rather than ending the game.
            while True:
                try:
                    action = parse_action(input("Your move (index or 'row col'): "))
                    # verify_move=True makes get_next_state return False on illegal input.
                    next_state = game.get_next_state(state, player, action, verify_move=True)
                    if next_state is False:
                        raise ValueError(f"Illegal move: {action}")
                    state, player = next_state
                    break
                except ValueError as exc:
                    print(exc)
        else:
            action = ai_action(game, model, state, player, args.num_simulations)
            print(f"AI move: {action}")
            state, player = game.get_next_state(state, player, action)
|
|
||||||
|
|
||||||
|
def build_parser():
    """Build the argparse CLI with the train / eval / play / arena subcommands.

    Each subcommand stores its handler via set_defaults(func=...), so main()
    can dispatch with args.func(args).
    """
    parser = argparse.ArgumentParser(description="Ultimate Tic-Tac-Toe Runner")
    subparsers = parser.add_subparsers(dest="command", required=True)

    train_parser = subparsers.add_parser("train", help="Train the model with self-play")
    train_parser.add_argument("--device")
    train_parser.add_argument("--checkpoint", default=DEFAULT_ARGS["checkpoint_path"])
    train_parser.add_argument("--resume", action="store_true")
    # Every numeric hyper-parameter defaults to its entry in DEFAULT_ARGS.
    for flag, flag_type, key in (
        ("--num-simulations", int, "num_simulations"),
        ("--num-iters", int, "numIters"),
        ("--num-eps", int, "numEps"),
        ("--epochs", int, "epochs"),
        ("--batch-size", int, "batch_size"),
        ("--lr", float, "lr"),
        ("--weight-decay", float, "weight_decay"),
        ("--replay-buffer-size", int, "replay_buffer_size"),
        ("--value-loss-weight", float, "value_loss_weight"),
        ("--grad-clip-norm", float, "grad_clip_norm"),
        ("--temperature-threshold", int, "temperature_threshold"),
        ("--root-dirichlet-alpha", float, "root_dirichlet_alpha"),
        ("--root-exploration-fraction", float, "root_exploration_fraction"),
        ("--arena-compare-games", int, "arena_compare_games"),
        ("--arena-accept-threshold", float, "arena_accept_threshold"),
        ("--arena-compare-simulations", int, "arena_compare_simulations"),
    ):
        train_parser.add_argument(flag, type=flag_type, default=DEFAULT_ARGS[key])
    train_parser.set_defaults(func=train_command)

    eval_parser = subparsers.add_parser("eval", help="Inspect a checkpoint on a position")
    eval_parser.add_argument("--device")
    eval_parser.add_argument("--checkpoint", default=DEFAULT_ARGS["checkpoint_path"])
    eval_parser.add_argument("--moves", default="", help="Comma-separated move sequence")
    eval_parser.add_argument("--top-k", type=int, default=10)
    eval_parser.add_argument("--with-mcts", action="store_true")
    eval_parser.add_argument("--num-simulations", type=int, default=DEFAULT_ARGS["num_simulations"])
    eval_parser.set_defaults(func=eval_command)

    play_parser = subparsers.add_parser("play", help="Play against the checkpoint")
    play_parser.add_argument("--device")
    play_parser.add_argument("--checkpoint", default=DEFAULT_ARGS["checkpoint_path"])
    play_parser.add_argument("--human-player", type=int, choices=[1, -1], default=1)
    play_parser.add_argument("--num-simulations", type=int, default=DEFAULT_ARGS["num_simulations"])
    play_parser.set_defaults(func=play_command)

    arena_parser = subparsers.add_parser("arena", help="Run repeated matches between agents")
    arena_parser.add_argument("--device")
    arena_parser.add_argument("--games", type=int, default=20)
    arena_parser.add_argument("--num-simulations", type=int, default=DEFAULT_ARGS["num_simulations"])
    arena_parser.add_argument("--x-player", choices=["checkpoint", "random"], default="checkpoint")
    arena_parser.add_argument("--o-player", choices=["checkpoint", "random"], default="random")
    arena_parser.add_argument("--x-checkpoint", default=DEFAULT_ARGS["checkpoint_path"])
    arena_parser.add_argument("--o-checkpoint", default=DEFAULT_ARGS["checkpoint_path"])
    arena_parser.set_defaults(func=arena_command)

    return parser
|
|
||||||
|
|
||||||
|
def main():
    """CLI entry point: parse arguments and run the selected subcommand handler."""
    cli = build_parser()
    parsed = cli.parse_args()
    parsed.func(parsed)


if __name__ == "__main__":
    main()
|
||||||
258
trainer.py
Normal file
258
trainer.py
Normal file
@@ -0,0 +1,258 @@
|
|||||||
|
import os
|
||||||
|
import math
|
||||||
|
import copy
|
||||||
|
import numpy as np
|
||||||
|
from collections import deque
|
||||||
|
from random import shuffle
|
||||||
|
from progressbar import ProgressBar, Percentage, Bar, ETA, AdaptiveETA, FormatLabel
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.optim as optim
|
||||||
|
|
||||||
|
from mcts import MCTS
|
||||||
|
|
||||||
|
# Shared mutable state: updated once per iteration by Trainer.learn and
# rendered live by LossWidget inside the progress bar.
loss = {
    'ploss': float('inf'),
    'pkl': float('inf'),
    'vloss': float('inf'),
    'current': 0,
    'max': 0,
}


class LossWidget:
    """Progressbar widget that renders the latest losses from the global `loss` dict."""

    def __call__(self, progress, data=None):
        head = f" {loss['current']}/{loss['max']} "
        tail = f"P.CE: {loss['ploss']:.4f}, P.KL: {loss['pkl']:.4f}, V.Loss: {loss['vloss']:.4f}"
        return head + tail
|
|
||||||
|
class Trainer:
    """AlphaZero-style training loop: self-play -> gradient updates -> arena gating.

    Each iteration generates self-play episodes with MCTS, trains on the
    replay buffer, then pits the updated model against a frozen copy of
    itself; the new weights are only kept (and checkpointed) if they pass
    the acceptance threshold.
    """

    def __init__(self, game, model, args):
        """Store game/model/args and set up the optimizer and replay buffer.

        args is a dict; hyper-parameters are read with .get() fallbacks below.
        """
        self.game = game
        self.model = model
        self.args = args
        self.mcts = MCTS(self.game, self.model, self.args)
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=self.args.get('lr', 5e-4),
            weight_decay=self.args.get('weight_decay', 1e-4),
        )
        # Bounded FIFO of training examples; old examples fall off the left.
        self.replay_buffer = deque(maxlen=self.args.get('replay_buffer_size', 50000))

    def _ai_action(self, model, state, player):
        """Greedy MCTS move for `player` with root exploration noise disabled.

        Used by arena games so both models play deterministically at
        'arena_compare_simulations' rollouts (falls back to num_simulations).
        """
        board_data, active_board = state
        canonical_state = (self.game.get_canonical_board_data(board_data, player), active_board)
        mcts_args = dict(self.args)
        mcts_args.update(
            {
                'num_simulations': self.args.get('arena_compare_simulations', self.args['num_simulations']),
                'root_dirichlet_alpha': None,
                'root_exploration_fraction': None,
            }
        )
        root = MCTS(self.game, model, mcts_args).run(model, canonical_state, to_play=1)
        # temperature=0 -> pick the most-visited child.
        return root.select_action(temperature=0)

    def _play_arena_game(self, x_model, o_model):
        """Play one arena game; return 1 (X model wins), -1 (O model wins) or 0 (draw)."""
        state = self.game.get_init_board()
        current_player = 1

        while True:
            reward = self.game.get_reward_for_player(state, current_player)
            if reward is not None:
                if reward == 0:
                    return 0
                # Reward is relative to the side to move; convert to absolute winner.
                return current_player if reward == 1 else -current_player

            model = x_model if current_player == 1 else o_model
            action = self._ai_action(model, state, current_player)
            state, current_player = self.game.get_next_state(state, current_player, action)

    def evaluate_candidate(self, candidate_model, reference_model):
        """Arena-compare candidate vs reference; return (accepted, score in [0, 1]).

        The candidate plays X for the first half of the games and O for the
        second half (colors swapped for fairness). Wins score 1, draws 0.5.
        The loop exits early once acceptance or rejection is mathematically
        decided. With arena_compare_games <= 0 the gate is disabled and the
        candidate is accepted with a neutral 0.5 score.
        """
        games = self.args.get('arena_compare_games', 0)
        if games <= 0:
            return True, 0.5

        candidate_points = 0.0
        required_points = self.args.get('arena_accept_threshold', 0.55) * games
        candidate_first_games = (games + 1) // 2
        candidate_second_games = games // 2
        games_played = 0

        # Phase 1: candidate plays X.
        for _ in range(candidate_first_games):
            winner = self._play_arena_game(candidate_model, reference_model)
            if winner == 1:
                candidate_points += 1.0
            elif winner == 0:
                candidate_points += 0.5
            games_played += 1
            remaining_games = games - games_played
            # Early accept: threshold already reached.
            if candidate_points >= required_points:
                return True, candidate_points / games
            # Early reject: even winning every remaining game cannot reach it.
            if candidate_points + remaining_games < required_points:
                return False, candidate_points / games

        # Phase 2: candidate plays O, so winner == -1 is a candidate win.
        for _ in range(candidate_second_games):
            winner = self._play_arena_game(reference_model, candidate_model)
            if winner == -1:
                candidate_points += 1.0
            elif winner == 0:
                candidate_points += 0.5
            games_played += 1
            remaining_games = games - games_played
            if candidate_points >= required_points:
                return True, candidate_points / games
            if candidate_points + remaining_games < required_points:
                return False, candidate_points / games

        score = candidate_points / games
        return score >= self.args.get('arena_accept_threshold', 0.55), score

    # NOTE(review): method name is a typo for "execute_episode" — kept as-is
    # because learn() (and possibly external callers) use this spelling.
    def exceute_episode(self):
        """Play one self-play game and return its training examples.

        Returns a list of (encoded_state, mcts_policy, value) tuples, where
        value is the final reward from the perspective of the player who was
        to move in that position.
        """
        train_examples = []
        current_player = 1
        state = self.game.get_init_board()
        episode_step = 0

        while True:
            # Search always runs on the canonical (player-to-move) view.
            board_data, state_inner_board = state
            cannonical_board_data = self.game.get_canonical_board_data(board_data, current_player)
            canonical_board = (cannonical_board_data, state_inner_board)

            # Fresh tree each move: no search statistics are reused across moves.
            self.mcts = MCTS(self.game, self.model, self.args)
            root = self.mcts.run(self.model, canonical_board, to_play=1)

            # Policy target = normalized root visit counts.
            action_probs = np.zeros(self.game.get_action_size(), dtype=np.float32)
            for k, v in root.children.items():
                action_probs[k] = v.visit_count

            action_probs = action_probs / np.sum(action_probs)
            encoded_state = self.game.encode_state(canonical_board)
            train_examples.append((encoded_state, current_player, action_probs))

            # Sample proportionally to visits early on, then play greedily.
            temperature_threshold = self.args.get('temperature_threshold', 10)
            temperature = 1 if episode_step < temperature_threshold else 0
            action = root.select_action(temperature=temperature)
            state, current_player = self.game.get_next_state(state, current_player, action)
            reward = self.game.get_reward_for_player(state, current_player)
            episode_step += 1

            if reward is not None:
                ret = []
                for hist_state, hist_current_player, hist_action_probs in train_examples:
                    # Example layout: (encoded board, action probabilities, reward).
                    # The sign flip credits the reward to whichever side was to
                    # move when the example was recorded.
                    ret.append((hist_state, hist_action_probs, reward * ((-1) ** (hist_current_player != current_player))))

                return ret

    def learn(self):
        """Run the full outer training loop for args['numIters'] iterations.

        Per iteration: generate numEps self-play games, train on the shuffled
        replay buffer, arena-gate the result, then update the global `loss`
        dict that drives the progress-bar widget.
        """
        widgets = [Percentage(), Bar(), AdaptiveETA(), LossWidget()]
        pbar = ProgressBar(max_value=self.args['numIters'], widgets=widgets)
        pbar.update(0)
        for i in range(1, self.args['numIters'] + 1):

            train_examples = []

            for eps in range(self.args['numEps']):
                iteration_train_examples = self.exceute_episode()
                train_examples.extend(iteration_train_examples)

            self.replay_buffer.extend(train_examples)
            training_examples = list(self.replay_buffer)
            shuffle(training_examples)
            shuffle(train_examples)  # NOTE(review): train_examples is not used after this; looks like dead code

            # Snapshot the pre-training model/optimizer so a rejected candidate
            # can be rolled back exactly.
            reference_model = copy.deepcopy(self.model)
            reference_model.eval()
            reference_state_dict = copy.deepcopy(self.model.state_dict())
            reference_optimizer_state = copy.deepcopy(self.optimizer.state_dict())

            results = self.train(training_examples)
            accepted, arena_score = self.evaluate_candidate(self.model, reference_model)
            if accepted:
                filename = self.args['checkpoint_path']
                self.save_checkpoint(folder=".", filename=filename)
            else:
                # Rejected: restore the snapshot and discard this iteration's update.
                self.model.load_state_dict(reference_state_dict)
                self.optimizer.load_state_dict(reference_optimizer_state)

            # Publish losses to the module-level dict read by LossWidget.
            loss['ploss'] = float(results[0])
            loss['pkl'] = float(results[1])
            loss['vloss'] = float(results[2])
            loss['current'] = i
            loss['max'] = self.args['numIters']
            pbar.update(i)
        pbar.finish()

    def train(self, examples):
        """Run args['epochs'] passes of SGD over `examples`.

        Each example is (board, policy target, value target). The optimized
        loss is policy_KL + value_loss_weight * value_MSE; the plain policy
        cross-entropy is tracked for logging only. Returns the tuple
        (mean policy CE, mean policy KL, mean value loss).
        """
        pi_losses = []
        pi_kls = []
        v_losses = []
        device = self.model.device

        for epoch in range(self.args['epochs']):
            self.model.train()
            shuffled_indices = np.random.permutation(len(examples))

            for batch_start in range(0, len(examples), self.args['batch_size']):
                # Final batch may be smaller than batch_size.
                sample_ids = shuffled_indices[batch_start:batch_start + self.args['batch_size']]
                boards, pis, vs = list(zip(*[examples[i] for i in sample_ids]))
                boards = torch.FloatTensor(np.array(boards).astype(np.float32))
                target_pis = torch.FloatTensor(np.array(pis))
                target_vs = torch.FloatTensor(np.array(vs).astype(np.float32))

                # Move the batch to the model's device.
                boards = boards.contiguous().to(device)
                target_pis = target_pis.contiguous().to(device)
                target_vs = target_vs.contiguous().to(device)

                # Forward pass and loss terms.
                out_pi, out_v = self.model(boards)
                l_pi = self.loss_pi(target_pis, out_pi)
                l_pi_kl = self.loss_pi_kl(target_pis, out_pi)
                l_v = self.loss_v(target_vs, out_v)
                total_loss = l_pi_kl + self.args.get('value_loss_weight', 1.0) * l_v

                pi_losses.append(float(l_pi.detach()))
                pi_kls.append(float(l_pi_kl.detach()))
                v_losses.append(float(l_v.detach()))

                self.optimizer.zero_grad()
                total_loss.backward()
                # Clip the global gradient norm before stepping.
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.get('grad_clip_norm', 5.0))
                self.optimizer.step()

        return (np.mean(pi_losses), np.mean(pi_kls), np.mean(v_losses))

    def loss_pi(self, targets, outputs):
        """Mean policy cross-entropy; outputs are clamped to avoid log(0).

        NOTE(review): assumes `outputs` are probabilities, not logits — confirm
        against the model's policy head.
        """
        loss = -(targets * torch.log(outputs.clamp_min(1e-8))).sum(dim=1)
        return loss.mean()

    def loss_pi_kl(self, targets, outputs):
        """Mean KL divergence KL(targets || outputs), clamped for numerical safety."""
        target_log = torch.log(targets.clamp_min(1e-8))
        output_log = torch.log(outputs.clamp_min(1e-8))
        loss = (targets * (target_log - output_log)).sum(dim=1)
        return loss.mean()

    def loss_v(self, targets, outputs):
        """Mean squared error between value targets and the flattened value head."""
        loss = torch.sum((targets-outputs.view(-1))**2)/targets.size()[0]
        return loss

    def save_checkpoint(self, folder, filename):
        """Save model weights, optimizer state, and args to folder/filename."""
        if not os.path.exists(folder):
            os.mkdir(folder)

        filepath = os.path.join(folder, filename)
        torch.save({
            'state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'args': self.args,
        }, filepath)
|
||||||
Reference in New Issue
Block a user