First commit
This commit is contained in:
4
.gitignore
vendored
Normal file
4
.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
__pycache__/
|
||||
.ipynb_checkpoints/
|
||||
*.pth
|
||||
|
||||
60
README.md
Normal file
60
README.md
Normal file
@@ -0,0 +1,60 @@
|
||||
# Ultimate Tic Tac Toe Deep Learning Bot
|
||||
|
||||
**Usage**
|
||||
Run `python run.py --help` for help.
|
||||
|
||||
**Shared flags**
|
||||
- --device: Torch device string. If omitted, it auto-picks "cuda" when available, otherwise "cpu".
|
||||
- --checkpoint: Path to the model checkpoint file. Default is latest.pth. It is loaded for eval, play, and checkpoint-based arena,
|
||||
and used as the save path for accepted training checkpoints.
|
||||
|
||||
**Training parameters**
|
||||
- --resume: Loads model and optimizer state from --checkpoint before continuing training.
|
||||
- --num-simulations default 100: MCTS rollouts per move during self-play. Higher is stronger/slower.
|
||||
- --num-iters default 50: Number of outer training iterations. Each iteration generates new self-play games, trains, then arena-tests
|
||||
the new model.
|
||||
- --num-eps default 20: Self-play games per iteration.
|
||||
- --epochs default 5: Passes over the current replay-buffer training set per iteration.
|
||||
- --batch-size default 64: Mini-batch size for gradient updates.
|
||||
- --lr default 5e-4: Adam learning rate.
|
||||
- --weight-decay default 1e-4: Adam weight decay (L2-style regularization).
|
||||
- --replay-buffer-size default 50000: Maximum number of training examples retained across iterations. Older examples are dropped.
|
||||
- --value-loss-weight default 1.0: Multiplier on the value-head loss in total training loss. Total loss is policy_KL +
|
||||
value_loss_weight * value_loss.
|
||||
- --grad-clip-norm default 5.0: Global gradient norm clipping threshold before optimizer step.
|
||||
- --temperature-threshold default 10: In self-play, moves before this step use stochastic sampling from MCTS visit counts; later
|
||||
moves use greedy selection.
|
||||
- --root-dirichlet-alpha default 0.3: Dirichlet noise alpha added to root priors during self-play MCTS to force exploration.
|
||||
- --root-exploration-fraction default 0.25: How much of that root prior is replaced by Dirichlet noise.
|
||||
- --arena-compare-games default 6: Number of head-to-head games between candidate and previous model after each iteration. If <= 0,
|
||||
every candidate is accepted.
|
||||
- --arena-accept-threshold default 0.55: Minimum average points needed in arena to keep the new model. Win = 1, draw = 0.5.
|
||||
- --arena-compare-simulations default 8: MCTS simulations per move during those arena comparison games. Separate from self-play
|
||||
--num-simulations.
|
||||
|
||||
**Evaluation parameters**
|
||||
|
||||
- --moves default "": Comma-separated move list to reach a position from the starting board, e.g. 0,10,4.
|
||||
- --top-k default 10: How many highest-probability legal moves to print from the model policy.
|
||||
- --with-mcts: Also run MCTS on that position and print the best move, instead of only raw network policy/value.
|
||||
- --num-simulations default 100: Only matters with --with-mcts; controls MCTS search depth for that evaluation.
|
||||
|
||||
**Play parameters**
|
||||
|
||||
- --human-player default 1: Which side you control. 1 means X, -1 means O.
|
||||
- --num-simulations default 100: MCTS simulations the AI uses for each move.
|
||||
|
||||
**Arena parameters**
|
||||
- --games default 20: Number of matches to run.
|
||||
- --num-simulations default 100: MCTS simulations per move for checkpoint-based players.
|
||||
- --x-player / --o-player: Either checkpoint or random. Chooses the agent type for each side.
|
||||
- --x-checkpoint / --o-checkpoint: Checkpoint path for that side when its player type is checkpoint. Ignored for random.
|
||||
|
||||
A few practical examples:
|
||||
|
||||
```bash
|
||||
python run.py train --num-iters 100 --num-eps 50 --resume
|
||||
python run.py eval --checkpoint latest.pth --moves 0,10,4 --with-mcts --num-simulations 200
|
||||
python run.py play --human-player -1 --num-simulations 300
|
||||
python run.py arena --games 50 --x-player checkpoint --o-player random
|
||||
```
|
||||
196
game.py
Normal file
196
game.py
Normal file
@@ -0,0 +1,196 @@
|
||||
import numpy as np
|
||||
|
||||
# Index triples (within a flattened 3x3 grid) that form a winning line:
# three rows, three columns, and the two diagonals.  Used both for small
# boards and for the meta-board of won small boards.
WIN_PATTERNS = [
    (0, 1, 2),
    (3, 4, 5),
    (6, 7, 8),
    (0, 3, 6),
    (1, 4, 7),
    (2, 5, 8),
    (0, 4, 8),
    (2, 4, 6),
]


class UltimateTicTacToe:
    """Ultimate Tic-Tac-Toe on a 9x9 grid of nine 3x3 "small" boards.

    A game state is a tuple ``(board_data, active_board)``:

    - ``board_data``: flat length-81 array; cell values are 1 (X),
      -1 (O) or 0 (empty).  Cell ``i`` sits at row ``i // 9``,
      column ``i % 9`` of the full grid.
    - ``active_board``: index 0-8 of the small board the next move must
      be played in, or ``None`` when any open small board is allowed.
    """

    def __init__(self):
        self.cells = 81          # number of cells == size of the action space
        self.board_width = 9     # side length of the full grid
        self.state_planes = 9    # feature planes produced by encode_state

    def get_init_board(self):
        """Return the empty starting state (no active-board constraint)."""
        b = np.zeros((self.cells,), dtype=int)
        return (b, None)

    def get_board_size(self):
        """Shape of the encoded state: (planes, height, width)."""
        return (self.state_planes, self.board_width, self.board_width)

    def get_action_size(self):
        """One action per cell of the full grid."""
        return self.cells

    def get_next_state(self, board, player, action, verify_move=False):
        """Apply *action* for *player*; return ((board_data, active_board), -player).

        When *verify_move* is set, returns False (instead of raising) if the
        move is illegal.  The cell's position inside its small board decides
        which small board the opponent must play in next; if that board is
        already closed (won or full), the opponent may play anywhere.
        """
        if verify_move and self.get_valid_moves(board)[action] == 0:
            return False

        new_board_data = np.copy(board[0])
        new_board_data[action] = player

        # Position of the cell within its small board -> forced next board.
        next_board = ((action // 9) % 3) * 3 + (action % 3)
        if self.is_board_full(new_board_data, next_board):
            next_board = None

        # Hand the turn to the opponent (perspective flips via -player).
        return ((new_board_data, next_board), -player)

    def is_board_full(self, board_data, next_board):
        """True when small board *next_board* is closed: won by either side
        or completely filled, so no further moves are allowed in it."""
        return (
            self._is_small_board_win(board_data, next_board, 1)
            or self._is_small_board_win(board_data, next_board, -1)
            or self._is_board_full(board_data, next_board)
        )

    def _small_board_cells(self, inner_board_idx):
        """Flat indices of the nine cells of small board *inner_board_idx*."""
        row_block = inner_board_idx // 3
        col_block = inner_board_idx % 3

        # Top-left cell of the small board in the flat 81-cell layout.
        base = row_block * 27 + col_block * 3

        return [
            base, base + 1, base + 2,
            base + 9, base + 10, base + 11,
            base + 18, base + 19, base + 20,
        ]

    def _is_board_full(self, board_data, next_board):
        """True when every cell of small board *next_board* is occupied."""
        return all(board_data[a] != 0 for a in self._small_board_cells(next_board))

    def _is_playable_small_board(self, board_data, inner_board_idx):
        """A small board accepts moves while it is neither won nor full."""
        return not self.is_board_full(board_data, inner_board_idx)

    def has_legal_moves(self, board):
        """True while at least one legal move remains anywhere on the board."""
        return any(move == 1 for move in self.get_valid_moves(board))

    def get_valid_moves(self, board):
        """Return a length-81 0/1 mask of the legal moves for *board*.

        Moves are restricted to the active small board; when there is no
        active board (or it is closed), any empty cell of any open small
        board is legal.
        """
        board_data, active_board = board
        valid_moves = [0] * self.get_action_size()

        # A closed active board releases the placement constraint.
        if active_board is not None and not self._is_playable_small_board(board_data, active_board):
            active_board = None

        if active_board is None:
            target_boards = [
                inner_board_idx
                for inner_board_idx in range(9)
                if self._is_playable_small_board(board_data, inner_board_idx)
            ]
        else:
            target_boards = [active_board]

        for inner_board_idx in target_boards:
            for index in self._small_board_cells(inner_board_idx):
                if board_data[index] == 0:
                    valid_moves[index] = 1

        return valid_moves

    def _is_small_board_win(self, board_data, inner_board_idx, player):
        """True when *player* holds a winning line in the given small board."""
        cells = self._small_board_cells(inner_board_idx)
        for a, b, c in WIN_PATTERNS:
            if board_data[cells[a]] == board_data[cells[b]] == board_data[cells[c]] == player:
                return True
        return False

    def is_win(self, board, player):
        """True when *player* has won three small boards forming a line."""
        board_data, _ = board
        won = [self._is_small_board_win(board_data, i, player) for i in range(9)]

        # The nine small-board results form a meta 3x3 board.
        return any(won[a] and won[b] and won[c] for a, b, c in WIN_PATTERNS)

    def get_reward_for_player(self, board, player):
        """Game result from *player*'s perspective.

        Returns None while the game is still running, 1 for a win, -1 for a
        loss, and 0 for a draw (no legal moves and no winner).
        """
        if self.is_win(board, player):
            return 1
        if self.is_win(board, -player):
            return -1
        if self.has_legal_moves(board):
            return None
        return 0

    def get_canonical_board_data(self, board_data, player):
        """Flip the board so the given player's stones are always +1."""
        return player * board_data

    def _small_board_mask(self, inner_board_idx):
        """9x9 float mask with 1.0 on the cells of the given small board."""
        mask = np.zeros((self.board_width, self.board_width), dtype=np.float32)
        for index in self._small_board_cells(inner_board_idx):
            row = index // self.board_width
            col = index % self.board_width
            mask[row, col] = 1.0
        return mask

    def encode_state(self, board):
        """Encode *board* into the (9, 9, 9) float32 planes fed to the network.

        Planes, in order: current-player stones, opponent stones, empty
        cells, legal moves, active-board mask, small boards won by the
        current player, small boards won by the opponent, open small
        boards, and a constant move-count-fraction plane.
        """
        board_data, active_board = board
        board_grid = board_data.reshape(self.board_width, self.board_width)

        current_stones = (board_grid == 1).astype(np.float32)
        opponent_stones = (board_grid == -1).astype(np.float32)
        empty_cells = (board_grid == 0).astype(np.float32)
        legal_moves = np.array(self.get_valid_moves(board), dtype=np.float32).reshape(self.board_width, self.board_width)

        # The mask only marks a *playable* active board; a closed one would
        # not constrain the next move anyway.
        active_board_mask = np.zeros((self.board_width, self.board_width), dtype=np.float32)
        if active_board is not None and self._is_playable_small_board(board_data, active_board):
            active_board_mask = self._small_board_mask(active_board)

        current_won_boards = np.zeros((self.board_width, self.board_width), dtype=np.float32)
        opponent_won_boards = np.zeros((self.board_width, self.board_width), dtype=np.float32)
        playable_boards = np.zeros((self.board_width, self.board_width), dtype=np.float32)

        for inner_board_idx in range(9):
            board_mask = self._small_board_mask(inner_board_idx)
            if self._is_small_board_win(board_data, inner_board_idx, 1):
                current_won_boards += board_mask
            elif self._is_small_board_win(board_data, inner_board_idx, -1):
                opponent_won_boards += board_mask

            if self._is_playable_small_board(board_data, inner_board_idx):
                playable_boards += board_mask

        # Fraction of the board already filled, broadcast as a constant plane.
        move_count = np.count_nonzero(board_data) / self.cells
        move_count_plane = np.full((self.board_width, self.board_width), move_count, dtype=np.float32)

        return np.stack(
            (
                current_stones,
                opponent_stones,
                empty_cells,
                legal_moves,
                active_board_mask,
                current_won_boards,
                opponent_won_boards,
                playable_boards,
                move_count_plane,
            ),
            axis=0,
        )
|
||||
190
mcts.py
Normal file
190
mcts.py
Normal file
@@ -0,0 +1,190 @@
|
||||
import torch
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
def ucb_score(parent, child):
    """Upper-confidence bound for following the edge from *parent* to *child*.

    Combines the child's prior probability (scaled by how much the parent
    has been explored) with the negated child value — negated because the
    child's value is stored from the opposing player's perspective.
    """
    exploration = child.prior * math.sqrt(parent.visit_count) / (child.visit_count + 1)
    exploitation = -child.value() if child.visit_count > 0 else 0
    return exploitation + exploration
|
||||
|
||||
|
||||
class Node:
    """A single search-tree node holding visit statistics and child edges."""

    def __init__(self, prior, to_play):
        self.visit_count = 0
        self.to_play = to_play       # player (1 or -1) to move at this node
        self.prior = prior           # policy probability assigned by the network
        self.value_sum = 0
        self.children = {}           # action -> Node
        self.state = None            # game state, filled in by expand()

    def expanded(self):
        """A node counts as expanded once it has at least one child."""
        return bool(self.children)

    def value(self):
        """Mean backed-up value over all visits (0 for an unvisited node)."""
        return self.value_sum / self.visit_count if self.visit_count else 0

    def select_action(self, temperature):
        """Sample an action from the children's visit-count distribution.

        temperature 0 is greedy (most-visited), infinity is uniform, and any
        other value tempers the visit counts before sampling.
        """
        actions = list(self.children.keys())
        visit_counts = np.array([c.visit_count for c in self.children.values()])

        if temperature == 0:
            return actions[int(np.argmax(visit_counts))]
        if temperature == float("inf"):
            return np.random.choice(actions)

        # See paper appendix, Data Generation.
        weights = visit_counts ** (1 / temperature)
        weights = weights / sum(weights)
        return np.random.choice(actions, p=weights)

    def select_child(self):
        """Return the (action, child) pair with the highest UCB score."""
        best = (-np.inf, -1, None)
        for action, child in self.children.items():
            score = ucb_score(self, child)
            if score > best[0]:
                best = (score, action, child)
        return best[1], best[2]

    def expand(self, state, to_play, action_probs):
        """Create one child per action with non-zero prior probability."""
        self.to_play = to_play
        self.state = state
        for action, prob in enumerate(action_probs):
            if prob != 0:
                self.children[action] = Node(prior=prob, to_play=-self.to_play)

    def __repr__(self):
        """Debugger pretty print node info."""
        return "{} Prior: {} Count: {} Value: {}".format(
            self.state.__str__(),
            "{0:.2f}".format(self.prior),
            self.visit_count,
            self.value(),
        )
|
||||
|
||||
|
||||
class MCTS:
    """Monte-Carlo tree search guided by the model's policy/value predictions.

    States are always evaluated from the perspective of the player to move
    (canonical boards), so the model is only ever queried as "player 1".
    """

    def __init__(self, game, model, args):
        # args is a dict-like config; 'num_simulations' is required and the
        # two root-noise keys are optional (absent/None disables the noise).
        self.game = game
        self.model = model
        self.args = args

    def _masked_policy(self, state, model):
        """Query *model* on *state*, zero out illegal moves, and renormalize.

        Returns (action_probs, value).  Falls back to a uniform distribution
        over legal moves when the network assigns them no probability mass.
        """
        encoded_state = self.game.encode_state(state)
        action_probs, value = model.predict(encoded_state)
        valid_moves = np.array(self.game.get_valid_moves(state), dtype=np.float32)
        action_probs = action_probs * valid_moves
        total_prob = np.sum(action_probs)
        total_valid = np.sum(valid_moves)
        if total_valid <= 0:
            # No legal moves at all: the all-zero mask doubles as the "policy".
            return valid_moves, float(value)
        if total_prob <= 0:
            # Network put no mass on any legal move; use a uniform fallback.
            action_probs = valid_moves / total_valid
        else:
            action_probs /= total_prob
        return action_probs, float(value)

    def _add_exploration_noise(self, node):
        """Mix Dirichlet noise into the root children's priors (AlphaZero-style).

        No-op when either noise parameter is missing/None or the node has no
        children.
        """
        alpha = self.args.get('root_dirichlet_alpha')
        fraction = self.args.get('root_exploration_fraction')
        if alpha is None or fraction is None or not node.children:
            return

        actions = list(node.children.keys())
        noise = np.random.dirichlet([alpha] * len(actions))
        for action, sample in zip(actions, noise):
            child = node.children[action]
            child.prior = child.prior * (1 - fraction) + sample * fraction

    def run(self, model, state, to_play):
        """Run `num_simulations` simulations from *state*; return the root node.

        *model* may be None, in which case the instance's model is used.
        *state* is assumed to already be canonical for the player to move.
        """
        model = model or self.model

        root = Node(0, to_play)
        root.state = state

        # Terminal position: record the reward directly, no search needed.
        reward = self.game.get_reward_for_player(state, player=1)
        if reward is not None:
            root.value_sum = float(reward)
            root.visit_count = 1
            return root

        # EXPAND root
        action_probs, value = self._masked_policy(state, model)
        root.expand(state, to_play, action_probs)
        if not root.children:
            # No children could be created; treat the root as a leaf.
            root.value_sum = float(value)
            root.visit_count = 1
            return root
        self._add_exploration_noise(root)

        for _ in range(self.args['num_simulations']):
            node = root
            search_path = [node]

            # SELECT: walk down the tree following the best-UCB child.
            xp = False
            while node.expanded():
                action, node = node.select_child()
                if node == None:
                    # Defensive branch: select_child yielded no child.  It
                    # backpropagates `value` from an earlier expansion —
                    # NOTE(review): looks like a stale value; confirm this
                    # path can actually occur.
                    parent = search_path[-1]
                    self.backpropagate(search_path, value, parent.to_play * -1)
                    xp = True
                    break
                search_path.append(node)
            if xp:
                continue
            parent = search_path[-2]
            state = parent.state
            # Now we're at a leaf node and we would like to expand
            # Players always play from their own perspective
            next_state, _ = self.game.get_next_state(state, player=1, action=action)
            # Get the board from the perspective of the other player
            next_state_data, next_state_inner = next_state

            next_state = (self.game.get_canonical_board_data(next_state_data, player=-1), next_state_inner)

            # The value of the new state from the perspective of the other player
            value = self.game.get_reward_for_player(next_state, player=1)
            if value is None:
                # If the game has not ended:
                # EXPAND
                action_probs, value = self._masked_policy(next_state, model)
                node.expand(next_state, parent.to_play * -1, action_probs)

            self.backpropagate(search_path, float(value), parent.to_play * -1)

        return root

    def backpropagate(self, search_path, value, to_play):
        """
        At the end of a simulation, we propagate the evaluation all the way up the tree
        to the root.

        *value* is from the perspective of *to_play*; each node on the path
        accumulates it with the sign flipped for the opposing side.
        """
        for node in reversed(search_path):
            node.value_sum += value if node.to_play == to_play else -value
            node.visit_count += 1
|
||||
81
model.py
Normal file
81
model.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with an identity skip connection."""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Skip connection, then the final non-linearity.
        return F.relu(out + x)


class UltimateTicTacToeModel(nn.Module):
    """AlphaZero-style residual network with separate policy and value heads.

    Input: an encoded board of shape ``board_size`` (planes, height, width).
    ``forward`` returns a softmax policy over all actions and a tanh value
    in [-1, 1].
    """

    def __init__(self, board_size, action_size, device, channels=64, num_blocks=6):
        super().__init__()

        self.action_size = action_size
        self.input_shape = board_size
        self.input_channels = board_size[0]
        self.board_height = board_size[1]
        self.board_width = board_size[2]
        self.device = torch.device(device)

        # Shared trunk: stem convolution followed by a residual tower.
        self.stem = nn.Sequential(
            nn.Conv2d(self.input_channels, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        )
        self.residual_tower = nn.Sequential(*(ResidualBlock(channels) for _ in range(num_blocks)))

        # Policy head: 1x1 conv then a linear layer over the flattened map.
        self.policy_head = nn.Sequential(
            nn.Conv2d(channels, 32, kernel_size=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )
        self.policy_fc = nn.Linear(32 * self.board_height * self.board_width, self.action_size)

        # Value head: 1x1 conv, one hidden linear layer, scalar output.
        self.value_head = nn.Sequential(
            nn.Conv2d(channels, 32, kernel_size=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )
        self.value_fc1 = nn.Linear(32 * self.board_height * self.board_width, 128)
        self.value_fc2 = nn.Linear(128, 1)

        self.to(self.device)

    def forward(self, x):
        """Return (policy_probs, value) for a batch of encoded boards."""
        trunk = self.residual_tower(self.stem(x.view(-1, *self.input_shape)))

        policy_logits = self.policy_fc(torch.flatten(self.policy_head(trunk), 1))

        value_hidden = torch.flatten(self.value_head(trunk), 1)
        value_hidden = F.relu(self.value_fc1(value_hidden))
        value = torch.tanh(self.value_fc2(value_hidden))

        return F.softmax(policy_logits, dim=1), value

    def predict(self, board):
        """Evaluate one encoded board; returns (policy ndarray, value float).

        Runs in eval mode without gradients, so batch-norm uses running stats.
        """
        tensor = torch.as_tensor(board, dtype=torch.float32, device=self.device)
        tensor = tensor.view(1, *self.input_shape)
        self.eval()
        with torch.no_grad():
            pi, v = self.forward(tensor)

        return pi.detach().cpu().numpy()[0], float(v.item())
|
||||
384
run.py
Normal file
384
run.py
Normal file
@@ -0,0 +1,384 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from game import UltimateTicTacToe
|
||||
from mcts import MCTS
|
||||
from model import UltimateTicTacToeModel
|
||||
from trainer import Trainer
|
||||
|
||||
|
||||
# Default hyper-parameters shared by every sub-command; each key mirrors a
# CLI flag defined in build_parser.  "numIters"/"numEps" keep the camelCase
# spelling that the Trainer config expects.
DEFAULT_ARGS = {
    "num_simulations": 100,            # MCTS rollouts per move in self-play
    "numIters": 50,                    # outer training iterations
    "numEps": 20,                      # self-play games per iteration
    "epochs": 5,                       # training passes per iteration
    "batch_size": 64,                  # mini-batch size for gradient updates
    "lr": 5e-4,                        # Adam learning rate
    "weight_decay": 1e-4,              # Adam weight decay
    "replay_buffer_size": 50000,       # max retained training examples
    "value_loss_weight": 1.0,          # value-head weight in the total loss
    "grad_clip_norm": 5.0,             # global gradient-norm clip threshold
    "checkpoint_path": "latest.pth",   # load/save path for model checkpoints
    "temperature_threshold": 10,       # moves before greedy self-play selection
    "root_dirichlet_alpha": 0.3,       # Dirichlet alpha for root noise
    "root_exploration_fraction": 0.25, # prior fraction replaced by noise
    "arena_compare_games": 6,          # candidate-vs-previous games per iter
    "arena_accept_threshold": 0.55,    # min average points to keep candidate
    "arena_compare_simulations": 8,    # MCTS sims per move in arena games
}
|
||||
|
||||
|
||||
def get_device(device_arg):
    """Resolve the torch device string.

    An explicit, non-empty argument wins; otherwise prefer CUDA when
    available and fall back to CPU.
    """
    if not device_arg:
        return "cuda" if torch.cuda.is_available() else "cpu"
    return device_arg
|
||||
|
||||
|
||||
def build_model(game, device):
    """Construct a fresh UltimateTicTacToeModel sized for *game* on *device*."""
    board_size = game.get_board_size()
    action_size = game.get_action_size()
    return UltimateTicTacToeModel(board_size, action_size, device)
|
||||
|
||||
|
||||
def load_checkpoint(model, checkpoint_path, device, optimizer=None, required=True):
    """Load model (and optionally optimizer) state from *checkpoint_path*.

    Returns True when the checkpoint was loaded.  If the file is missing,
    raises FileNotFoundError when *required*, otherwise returns False.
    A successfully loaded model is left in eval mode.
    """
    path = Path(checkpoint_path)
    if path.exists():
        state = torch.load(path, map_location=device)
        model.load_state_dict(state["state_dict"])
        # Optimizer state is optional: older checkpoints may not carry it.
        if optimizer is not None and "optimizer_state_dict" in state:
            optimizer.load_state_dict(state["optimizer_state_dict"])
        model.eval()
        return True
    if required:
        raise FileNotFoundError(f"Checkpoint not found: {path}")
    return False
|
||||
|
||||
|
||||
def canonical_state(game, state, player):
    """Return *state* flipped to the given player's perspective.

    The active-board index is perspective-independent and passes through.
    """
    board_data, active_board = state
    canonical_data = game.get_canonical_board_data(board_data, player)
    return (canonical_data, active_board)
|
||||
|
||||
|
||||
def apply_moves(game, moves):
    """Replay *moves* from the initial position, validating each one.

    Returns the resulting (state, player-to-move).  Raises ValueError on
    the first illegal move in the sequence.
    """
    state, player = game.get_init_board(), 1
    for action in moves:
        outcome = game.get_next_state(state, player, action, verify_move=True)
        if outcome is False:
            raise ValueError(f"Illegal move in sequence: {action}")
        state, player = outcome
    return state, player
|
||||
|
||||
|
||||
def format_board(board_data):
    """Render the 81-cell board as a human-readable 9x9 grid.

    Columns are separated into 3-cell groups with " | " and horizontal
    rules are inserted after rows 2 and 5 to outline the small boards.
    """
    glyphs = {1: "X", -1: "O", 0: "."}
    lines = []
    for row in range(9):
        marks = [glyphs[int(board_data[row * 9 + col])] for col in range(9)]
        triples = (" ".join(marks[start:start + 3]) for start in (0, 3, 6))
        lines.append(" | ".join(triples))
        if row == 2 or row == 5:
            lines.append("-" * 23)
    return "\n".join(lines)
|
||||
|
||||
|
||||
def top_policy_moves(policy, limit):
    """Return up to *limit* (action, probability) pairs, highest first."""
    order = np.argsort(policy)[::-1]
    return [(int(a), float(policy[a])) for a in order[:limit]]
|
||||
|
||||
|
||||
def parse_moves(text):
    """Parse a comma-separated move string ("0,10,4") into a list of ints.

    Empty or None input yields an empty list; blank segments are skipped.
    """
    if not text:
        return []
    tokens = (piece.strip() for piece in text.split(","))
    return [int(token) for token in tokens if token]
|
||||
|
||||
|
||||
def parse_action(text):
    """Parse user input as a flat move index ("12") or as "row col" ("1 3").

    Commas are treated like whitespace.  Raises ValueError for malformed
    input or moves outside the 81-cell board.
    """
    tokens = text.strip().replace(",", " ").split()
    if len(tokens) == 1:
        action = int(tokens[0])
    elif len(tokens) == 2:
        row, col = int(tokens[0]), int(tokens[1])
        if not (0 <= row < 9 and 0 <= col < 9):
            raise ValueError("Row and column must be in [0, 8].")
        action = row * 9 + col
    else:
        raise ValueError("Enter either a flat move index or 'row col'.")
    if not (0 <= action < 81):
        raise ValueError("Move index must be in [0, 80].")
    return action
|
||||
|
||||
|
||||
def scalar_value(value):
    """Collapse a scalar-like value (float, 0-d or 1-element array) to a float."""
    flat = np.asarray(value).reshape(-1)
    return float(flat[0])
|
||||
|
||||
|
||||
def train_command(args):
    """Entry point for the `train` sub-command.

    Builds the game and model, merges CLI flags over DEFAULT_ARGS into the
    Trainer config, optionally resumes from a checkpoint, and runs training.
    """
    device = get_device(args.device)
    game = UltimateTicTacToe()
    model = build_model(game, device)

    # Start from the shared defaults and overlay every CLI flag.
    train_args = dict(DEFAULT_ARGS)
    train_args.update(
        {
            "num_simulations": args.num_simulations,
            "numIters": args.num_iters,
            "numEps": args.num_eps,
            "epochs": args.epochs,
            "batch_size": args.batch_size,
            "lr": args.lr,
            "weight_decay": args.weight_decay,
            "replay_buffer_size": args.replay_buffer_size,
            "value_loss_weight": args.value_loss_weight,
            "grad_clip_norm": args.grad_clip_norm,
            "checkpoint_path": args.checkpoint,
            "temperature_threshold": args.temperature_threshold,
            "root_dirichlet_alpha": args.root_dirichlet_alpha,
            "root_exploration_fraction": args.root_exploration_fraction,
            "arena_compare_games": args.arena_compare_games,
            "arena_accept_threshold": args.arena_accept_threshold,
            "arena_compare_simulations": args.arena_compare_simulations,
        }
    )

    trainer = Trainer(game, model, train_args)
    if args.resume:
        # Restore both model and optimizer state so training continues
        # from where the previous run left off.
        load_checkpoint(model, args.checkpoint, device, optimizer=trainer.optimizer)
    trainer.learn()
|
||||
|
||||
|
||||
def eval_command(args):
    """Entry point for the `eval` sub-command.

    Replays --moves from the start position, prints the board, the raw
    network policy/value, and (with --with-mcts) the MCTS best move.
    """
    device = get_device(args.device)
    game = UltimateTicTacToe()
    model = build_model(game, device)
    load_checkpoint(model, args.checkpoint, device)

    moves = parse_moves(args.moves)
    state, player = apply_moves(game, moves)
    # The network always evaluates from the side-to-move's perspective.
    current_state = canonical_state(game, state, player)
    encoded = game.encode_state(current_state)
    policy, value = model.predict(encoded)
    # Mask out illegal moves and renormalize so probabilities sum to 1.
    legal_mask = np.array(game.get_valid_moves(state), dtype=np.float32)
    policy = policy * legal_mask
    if policy.sum() > 0:
        policy = policy / policy.sum()

    print("Board:")
    print(format_board(state[0]))
    print()
    print(f"Side to move: {'X' if player == 1 else 'O'}")
    print(f"Active small board: {state[1]}")
    print(f"Model value: {scalar_value(value):.4f}")
    print("Top policy moves:")
    for action, prob in top_policy_moves(policy, args.top_k):
        print(f"  {action:2d} -> {prob:.4f}")

    if args.with_mcts:
        # Deterministic search: root noise disabled for evaluation.
        mcts_args = dict(DEFAULT_ARGS)
        mcts_args.update(
            {
                "num_simulations": args.num_simulations,
                "root_dirichlet_alpha": None,
                "root_exploration_fraction": None,
            }
        )
        root = MCTS(game, model, mcts_args).run(model, current_state, to_play=1)
        action = root.select_action(temperature=0)
        print(f"MCTS best move: {action}")
|
||||
|
||||
|
||||
def ai_action(game, model, state, player, num_simulations):
    """Pick the AI's move for *player*.

    Runs a noise-free MCTS from the player's canonical perspective and
    returns the most-visited action.
    """
    perspective_state = canonical_state(game, state, player)
    search_args = dict(DEFAULT_ARGS)
    search_args.update(
        {
            "num_simulations": num_simulations,
            "root_dirichlet_alpha": None,
            "root_exploration_fraction": None,
        }
    )
    root = MCTS(game, model, search_args).run(model, perspective_state, to_play=1)
    return root.select_action(temperature=0)
|
||||
|
||||
|
||||
def random_action(game, state):
    """Return a uniformly random legal action; raise ValueError if none exist."""
    legal = [action for action, allowed in enumerate(game.get_valid_moves(state)) if allowed]
    if not legal:
        raise ValueError("No legal actions available.")
    return int(np.random.choice(legal))
|
||||
|
||||
|
||||
def load_player_model(game, checkpoint, device):
    """Build a model for *game* and load the given checkpoint into it."""
    player_model = build_model(game, device)
    load_checkpoint(player_model, checkpoint, device)
    return player_model
|
||||
|
||||
|
||||
def choose_action(game, player_kind, model, state, player, num_simulations):
    """Dispatch to a random or MCTS-backed move depending on *player_kind*."""
    if player_kind != "random":
        return ai_action(game, model, state, player, num_simulations)
    return random_action(game, state)
|
||||
|
||||
|
||||
def play_match(game, x_kind, x_model, o_kind, o_model, num_simulations):
    """Play one full game between the X and O agents.

    Returns 1 if X wins, -1 if O wins, 0 for a draw.
    """
    state = game.get_init_board()
    player = 1

    while True:
        # Reward is from the perspective of the side to move; translate it
        # back into an absolute winner (1 = X, -1 = O, 0 = draw).
        reward = game.get_reward_for_player(state, player)
        if reward is not None:
            if reward == 0:
                return 0
            return player if reward == 1 else -player

        if player == 1:
            action = choose_action(game, x_kind, x_model, state, player, num_simulations)
        else:
            action = choose_action(game, o_kind, o_model, state, player, num_simulations)
        state, player = game.get_next_state(state, player, action)
|
||||
|
||||
|
||||
def arena_command(args):
    """Entry point for the `arena` sub-command.

    Pits the configured X and O agents against each other for --games
    matches and prints the aggregate win/draw counts.
    """
    device = get_device(args.device)
    game = UltimateTicTacToe()

    # Models are only loaded for sides configured as "checkpoint" players;
    # random players leave their model as None.
    x_model = None
    o_model = None
    if args.x_player == "checkpoint":
        x_model = load_player_model(game, args.x_checkpoint, device)
    if args.o_player == "checkpoint":
        o_model = load_player_model(game, args.o_checkpoint, device)

    # Tally keyed by winner: 1 = X wins, -1 = O wins, 0 = draws.
    results = {1: 0, -1: 0, 0: 0}
    for _ in range(args.games):
        winner = play_match(
            game,
            args.x_player,
            x_model,
            args.o_player,
            o_model,
            args.num_simulations,
        )
        results[winner] += 1

    print(f"Games: {args.games}")
    print(f"X ({args.x_player}) wins: {results[1]}")
    print(f"O ({args.o_player}) wins: {results[-1]}")
    print(f"Draws: {results[0]}")
|
||||
|
||||
|
||||
def play_command(args):
    """Entry point for the `play` sub-command: interactive human-vs-AI game.

    Loops printing the board, prompting the human on their turn (re-asking
    until a legal move is entered) and running MCTS for the AI's turn, until
    the game ends.
    """
    device = get_device(args.device)
    game = UltimateTicTacToe()
    model = build_model(game, device)
    load_checkpoint(model, args.checkpoint, device)

    state = game.get_init_board()
    player = 1
    human_player = args.human_player

    while True:
        print()
        print(format_board(state[0]))
        print(f"Turn: {'X' if player == 1 else 'O'}")
        print(f"Active small board: {state[1]}")

        # Reward is from the side-to-move's perspective; translate into X/O.
        reward = game.get_reward_for_player(state, player)
        if reward is not None:
            if reward == 0:
                print("Result: draw")
            else:
                winner = player if reward == 1 else -player
                print(f"Winner: {'X' if winner == 1 else 'O'}")
            return

        valid_moves = game.get_valid_moves(state)
        legal_actions = [index for index, allowed in enumerate(valid_moves) if allowed]
        print(f"Legal moves: {legal_actions}")

        if player == human_player:
            # Keep prompting until the human enters a parseable, legal move.
            while True:
                try:
                    action = parse_action(input("Your move (index or 'row col'): "))
                    next_state = game.get_next_state(state, player, action, verify_move=True)
                    if next_state is False:
                        raise ValueError(f"Illegal move: {action}")
                    state, player = next_state
                    break
                except ValueError as exc:
                    print(exc)
        else:
            action = ai_action(game, model, state, player, args.num_simulations)
            print(f"AI move: {action}")
            state, player = game.get_next_state(state, player, action)
|
||||
|
||||
|
||||
def build_parser():
    """Assemble the argparse CLI with train / eval / play / arena subcommands."""
    parser = argparse.ArgumentParser(description="Ultimate Tic-Tac-Toe Runner")
    subparsers = parser.add_subparsers(dest="command", required=True)

    # --- train: self-play training loop -----------------------------------
    train = subparsers.add_parser("train", help="Train the model with self-play")
    train.add_argument("--device")
    train.add_argument("--checkpoint", default=DEFAULT_ARGS["checkpoint_path"])
    train.add_argument("--resume", action="store_true")
    # (cli flag, value type, DEFAULT_ARGS key) for every numeric option.
    for flag, cast, key in [
        ("--num-simulations", int, "num_simulations"),
        ("--num-iters", int, "numIters"),
        ("--num-eps", int, "numEps"),
        ("--epochs", int, "epochs"),
        ("--batch-size", int, "batch_size"),
        ("--lr", float, "lr"),
        ("--weight-decay", float, "weight_decay"),
        ("--replay-buffer-size", int, "replay_buffer_size"),
        ("--value-loss-weight", float, "value_loss_weight"),
        ("--grad-clip-norm", float, "grad_clip_norm"),
        ("--temperature-threshold", int, "temperature_threshold"),
        ("--root-dirichlet-alpha", float, "root_dirichlet_alpha"),
        ("--root-exploration-fraction", float, "root_exploration_fraction"),
        ("--arena-compare-games", int, "arena_compare_games"),
        ("--arena-accept-threshold", float, "arena_accept_threshold"),
        ("--arena-compare-simulations", int, "arena_compare_simulations"),
    ]:
        train.add_argument(flag, type=cast, default=DEFAULT_ARGS[key])
    train.set_defaults(func=train_command)

    # --- eval: inspect network output on a position -----------------------
    evaluate = subparsers.add_parser("eval", help="Inspect a checkpoint on a position")
    evaluate.add_argument("--device")
    evaluate.add_argument("--checkpoint", default=DEFAULT_ARGS["checkpoint_path"])
    evaluate.add_argument("--moves", default="", help="Comma-separated move sequence")
    evaluate.add_argument("--top-k", type=int, default=10)
    evaluate.add_argument("--with-mcts", action="store_true")
    evaluate.add_argument("--num-simulations", type=int, default=DEFAULT_ARGS["num_simulations"])
    evaluate.set_defaults(func=eval_command)

    # --- play: human vs. checkpoint ---------------------------------------
    human_game = subparsers.add_parser("play", help="Play against the checkpoint")
    human_game.add_argument("--device")
    human_game.add_argument("--checkpoint", default=DEFAULT_ARGS["checkpoint_path"])
    human_game.add_argument("--human-player", type=int, choices=[1, -1], default=1)
    human_game.add_argument("--num-simulations", type=int, default=DEFAULT_ARGS["num_simulations"])
    human_game.set_defaults(func=play_command)

    # --- arena: automated head-to-head matches ----------------------------
    arena = subparsers.add_parser("arena", help="Run repeated matches between agents")
    arena.add_argument("--device")
    arena.add_argument("--games", type=int, default=20)
    arena.add_argument("--num-simulations", type=int, default=DEFAULT_ARGS["num_simulations"])
    arena.add_argument("--x-player", choices=["checkpoint", "random"], default="checkpoint")
    arena.add_argument("--o-player", choices=["checkpoint", "random"], default="random")
    arena.add_argument("--x-checkpoint", default=DEFAULT_ARGS["checkpoint_path"])
    arena.add_argument("--o-checkpoint", default=DEFAULT_ARGS["checkpoint_path"])
    arena.set_defaults(func=arena_command)

    return parser
||||
|
||||
def main():
    """CLI entry point: parse command-line args and dispatch to the subcommand."""
    cli = build_parser()
    parsed = cli.parse_args()
    # Each subparser registered its handler via set_defaults(func=...).
    parsed.func(parsed)


if __name__ == "__main__":
    main()
|
||||
258
trainer.py
Normal file
258
trainer.py
Normal file
@@ -0,0 +1,258 @@
|
||||
import os
|
||||
import math
|
||||
import copy
|
||||
import numpy as np
|
||||
from collections import deque
|
||||
from random import shuffle
|
||||
from progressbar import ProgressBar, Percentage, Bar, ETA, AdaptiveETA, FormatLabel
|
||||
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
|
||||
from mcts import MCTS
|
||||
|
||||
# Mutable module-level state shared between Trainer.learn and the progress-bar
# widget below; Trainer.learn overwrites these values after every iteration.
loss = {'ploss': float('inf'), 'pkl': float('inf'), 'vloss': float('inf'), 'current': 0, 'max': 0}


class LossWidget:
    """Progress-bar widget that renders the latest training losses."""

    def __call__(self, progress, data=None):
        # Read the live values from the shared ``loss`` dict on every render.
        return " {}/{} P.CE: {:.4f}, P.KL: {:.4f}, V.Loss: {:.4f}".format(
            loss['current'],
            loss['max'],
            loss['ploss'],
            loss['pkl'],
            loss['vloss'],
        )
||||
class Trainer:
    """AlphaZero-style training loop: self-play -> gradient updates -> arena gating.

    ``args`` is a plain dict of hyperparameters (see the CLI defaults); new
    weights are only kept when they beat the previous model in the arena.
    """

    def __init__(self, game, model, args):
        self.game = game
        self.model = model
        self.args = args
        self.mcts = MCTS(self.game, self.model, self.args)
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=self.args.get('lr', 5e-4),
            weight_decay=self.args.get('weight_decay', 1e-4),
        )
        # Sliding window of training examples; the oldest fall off the left.
        self.replay_buffer = deque(maxlen=self.args.get('replay_buffer_size', 50000))

    def _ai_action(self, model, state, player):
        """Pick a deterministic (temperature=0) MCTS move for arena play.

        Root Dirichlet noise is disabled so arena games measure pure playing
        strength rather than exploration behaviour.
        """
        board_data, active_board = state
        canonical_state = (self.game.get_canonical_board_data(board_data, player), active_board)
        mcts_args = dict(self.args)
        mcts_args.update(
            {
                'num_simulations': self.args.get('arena_compare_simulations', self.args['num_simulations']),
                'root_dirichlet_alpha': None,
                'root_exploration_fraction': None,
            }
        )
        root = MCTS(self.game, model, mcts_args).run(model, canonical_state, to_play=1)
        return root.select_action(temperature=0)

    def _play_arena_game(self, x_model, o_model):
        """Play one full game; return 1 (X wins), -1 (O wins) or 0 (draw)."""
        state = self.game.get_init_board()
        current_player = 1

        while True:
            reward = self.game.get_reward_for_player(state, current_player)
            if reward is not None:
                if reward == 0:
                    return 0
                # reward is from current_player's perspective; map back to X/O.
                return current_player if reward == 1 else -current_player

            model = x_model if current_player == 1 else o_model
            action = self._ai_action(model, state, current_player)
            state, current_player = self.game.get_next_state(state, current_player, action)

    def evaluate_candidate(self, candidate_model, reference_model):
        """Arena-test ``candidate_model`` against ``reference_model``.

        Colours alternate across two halves so neither model keeps the
        first-move advantage, and play stops early once the outcome is
        mathematically decided either way. Returns ``(accepted, score)``
        where score = candidate points (win=1, draw=0.5) / games.
        """
        games = self.args.get('arena_compare_games', 0)
        if games <= 0:
            # Gating disabled: accept unconditionally with a neutral score.
            return True, 0.5

        # BUGFIX: put both networks in inference mode for the arena. The
        # candidate otherwise stays in train() mode after the gradient
        # updates, so dropout/batch-norm would corrupt the gating decision.
        candidate_model.eval()
        reference_model.eval()

        candidate_points = 0.0
        required_points = self.args.get('arena_accept_threshold', 0.55) * games
        candidate_first_games = (games + 1) // 2
        candidate_second_games = games // 2
        games_played = 0

        # First half: candidate plays X.
        for _ in range(candidate_first_games):
            winner = self._play_arena_game(candidate_model, reference_model)
            if winner == 1:
                candidate_points += 1.0
            elif winner == 0:
                candidate_points += 0.5
            games_played += 1
            remaining_games = games - games_played
            if candidate_points >= required_points:
                return True, candidate_points / games
            if candidate_points + remaining_games < required_points:
                return False, candidate_points / games

        # Second half: candidate plays O.
        for _ in range(candidate_second_games):
            winner = self._play_arena_game(reference_model, candidate_model)
            if winner == -1:
                candidate_points += 1.0
            elif winner == 0:
                candidate_points += 0.5
            games_played += 1
            remaining_games = games - games_played
            if candidate_points >= required_points:
                return True, candidate_points / games
            if candidate_points + remaining_games < required_points:
                return False, candidate_points / games

        score = candidate_points / games
        return score >= self.args.get('arena_accept_threshold', 0.55), score

    def execute_episode(self):
        """Play one self-play game from the start position.

        Returns a list of (encoded_state, action_probs, outcome) training
        examples, with the final result sign-flipped for the losing side's
        positions.
        """
        train_examples = []
        current_player = 1
        state = self.game.get_init_board()
        episode_step = 0

        while True:
            board_data, active_board = state
            canonical_board_data = self.game.get_canonical_board_data(board_data, current_player)
            canonical_board = (canonical_board_data, active_board)

            # Fresh search tree for every move (no subtree reuse).
            self.mcts = MCTS(self.game, self.model, self.args)
            root = self.mcts.run(self.model, canonical_board, to_play=1)

            # Normalized visit counts form the policy training target.
            action_probs = np.zeros(self.game.get_action_size(), dtype=np.float32)
            for action_index, child in root.children.items():
                action_probs[action_index] = child.visit_count
            action_probs = action_probs / np.sum(action_probs)

            encoded_state = self.game.encode_state(canonical_board)
            train_examples.append((encoded_state, current_player, action_probs))

            # Early moves sample stochastically for diversity; later moves
            # are played greedily.
            temperature_threshold = self.args.get('temperature_threshold', 10)
            temperature = 1 if episode_step < temperature_threshold else 0
            action = root.select_action(temperature=temperature)
            state, current_player = self.game.get_next_state(state, current_player, action)
            reward = self.game.get_reward_for_player(state, current_player)
            episode_step += 1

            if reward is not None:
                # Propagate the final result to every stored position:
                # [encoded_state, action_probs, reward-from-that-player's-view]
                return [
                    (hist_state, hist_action_probs,
                     reward * ((-1) ** (hist_player != current_player)))
                    for hist_state, hist_player, hist_action_probs in train_examples
                ]

    # Backwards-compatible alias for the original (misspelled) method name.
    exceute_episode = execute_episode

    def learn(self):
        """Outer training loop: self-play, train, arena-gate, checkpoint."""
        widgets = [Percentage(), Bar(), AdaptiveETA(), LossWidget()]
        pbar = ProgressBar(max_value=self.args['numIters'], widgets=widgets)
        pbar.update(0)
        for i in range(1, self.args['numIters'] + 1):
            # Generate fresh self-play data for this iteration.
            train_examples = []
            for _ in range(self.args['numEps']):
                train_examples.extend(self.execute_episode())

            self.replay_buffer.extend(train_examples)
            training_examples = list(self.replay_buffer)
            shuffle(training_examples)

            # Snapshot the current model/optimizer so training can be rolled
            # back if the arena rejects the new weights.
            reference_model = copy.deepcopy(self.model)
            reference_model.eval()
            reference_state_dict = copy.deepcopy(self.model.state_dict())
            reference_optimizer_state = copy.deepcopy(self.optimizer.state_dict())

            results = self.train(training_examples)
            accepted, arena_score = self.evaluate_candidate(self.model, reference_model)
            if accepted:
                self.save_checkpoint(folder=".", filename=self.args['checkpoint_path'])
            else:
                # Rejected: restore model and optimizer from the snapshot.
                self.model.load_state_dict(reference_state_dict)
                self.optimizer.load_state_dict(reference_optimizer_state)

            # Publish the latest losses to the shared progress-bar widget.
            loss['ploss'] = float(results[0])
            loss['pkl'] = float(results[1])
            loss['vloss'] = float(results[2])
            loss['current'] = i
            loss['max'] = self.args['numIters']
            pbar.update(i)
        pbar.finish()

    def train(self, examples):
        """Run ``epochs`` passes of gradient descent over ``examples``.

        examples: list of (encoded_board, target_pi, target_v) tuples.
        Returns mean (policy CE, policy KL, value MSE) across all batches.
        """
        pi_losses = []
        pi_kls = []
        v_losses = []
        device = self.model.device

        for _ in range(self.args['epochs']):
            self.model.train()
            shuffled_indices = np.random.permutation(len(examples))

            for batch_start in range(0, len(examples), self.args['batch_size']):
                sample_ids = shuffled_indices[batch_start:batch_start + self.args['batch_size']]
                boards, pis, vs = list(zip(*[examples[i] for i in sample_ids]))
                boards = torch.FloatTensor(np.array(boards).astype(np.float32)).contiguous().to(device)
                target_pis = torch.FloatTensor(np.array(pis)).contiguous().to(device)
                target_vs = torch.FloatTensor(np.array(vs).astype(np.float32)).contiguous().to(device)

                # Forward pass. Cross-entropy (l_pi) is logged only; the KL
                # term plus the weighted value loss drives the gradient.
                out_pi, out_v = self.model(boards)
                l_pi = self.loss_pi(target_pis, out_pi)
                l_pi_kl = self.loss_pi_kl(target_pis, out_pi)
                l_v = self.loss_v(target_vs, out_v)
                total_loss = l_pi_kl + self.args.get('value_loss_weight', 1.0) * l_v

                pi_losses.append(float(l_pi.detach()))
                pi_kls.append(float(l_pi_kl.detach()))
                v_losses.append(float(l_v.detach()))

                self.optimizer.zero_grad()
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.get('grad_clip_norm', 5.0))
                self.optimizer.step()

        return (np.mean(pi_losses), np.mean(pi_kls), np.mean(v_losses))

    def loss_pi(self, targets, outputs):
        """Cross-entropy between the target policy and predicted probabilities."""
        # clamp_min guards log(0) for actions with zero predicted mass.
        ce = -(targets * torch.log(outputs.clamp_min(1e-8))).sum(dim=1)
        return ce.mean()

    def loss_pi_kl(self, targets, outputs):
        """KL divergence KL(targets || outputs).

        Differs from cross-entropy only by the target entropy, which is
        constant w.r.t. the model, so both give the same gradients.
        """
        target_log = torch.log(targets.clamp_min(1e-8))
        output_log = torch.log(outputs.clamp_min(1e-8))
        kl = (targets * (target_log - output_log)).sum(dim=1)
        return kl.mean()

    def loss_v(self, targets, outputs):
        """Mean squared error between game outcomes and value predictions."""
        return torch.sum((targets - outputs.view(-1)) ** 2) / targets.size()[0]

    def save_checkpoint(self, folder, filename):
        """Persist model weights, optimizer state and args to folder/filename."""
        # makedirs handles nested paths; os.mkdir would fail on them.
        os.makedirs(folder, exist_ok=True)

        filepath = os.path.join(folder, filename)
        torch.save({
            'state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'args': self.args,
        }, filepath)
||||
Reference in New Issue
Block a user