First commit
This commit is contained in:
258
trainer.py
Normal file
258
trainer.py
Normal file
@@ -0,0 +1,258 @@
|
||||
import os
|
||||
import math
|
||||
import copy
|
||||
import numpy as np
|
||||
from collections import deque
|
||||
from random import shuffle
|
||||
from progressbar import ProgressBar, Percentage, Bar, ETA, AdaptiveETA, FormatLabel
|
||||
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
|
||||
from mcts import MCTS
|
||||
|
||||
# Shared mutable state: the most recent training losses and iteration
# counters, read by the progress-bar widget below and written by learn().
loss = {
    'ploss': float('inf'),
    'pkl': float('inf'),
    'vloss': float('inf'),
    'current': 0,
    'max': 0,
}


class LossWidget:
    """Progressbar widget that renders the current iteration and losses."""

    def __call__(self, progress, data=None):
        header = f" {loss['current']}/{loss['max']} "
        body = (
            f"P.CE: {loss['ploss']:.4f}, "
            f"P.KL: {loss['pkl']:.4f}, "
            f"V.Loss: {loss['vloss']:.4f}"
        )
        return header + body
||||
class Trainer:
    """Self-play training driver.

    Generates episodes with MCTS, trains the network on a replay buffer,
    and gates freshly trained weights through an arena match against the
    previous best model.
    """

    def __init__(self, game, model, args):
        """Set up search, optimizer and replay buffer.

        game  -- game-rules adapter (transitions, rewards, encodings)
        model -- torch policy/value network
        args  -- dict-like hyperparameter bag
        """
        self.game = game
        self.model = model
        self.args = args
        self.mcts = MCTS(game, model, args)
        learning_rate = args.get('lr', 5e-4)
        decay = args.get('weight_decay', 1e-4)
        self.optimizer = optim.Adam(
            model.parameters(), lr=learning_rate, weight_decay=decay
        )
        # Bounded buffer: oldest self-play examples fall off automatically.
        buffer_size = args.get('replay_buffer_size', 50000)
        self.replay_buffer = deque(maxlen=buffer_size)
||||
def _ai_action(self, model, state, player):
    """Pick a deterministic (temperature 0) move for *player*.

    Runs a fresh MCTS with arena settings: an optional dedicated
    simulation budget and no root exploration noise.
    """
    board_data, active_board = state
    canonical = self.game.get_canonical_board_data(board_data, player)
    canonical_state = (canonical, active_board)

    # Copy the hyperparameters, then override for evaluation play.
    simulations = self.args.get(
        'arena_compare_simulations', self.args['num_simulations']
    )
    mcts_args = dict(self.args)
    mcts_args['num_simulations'] = simulations
    mcts_args['root_dirichlet_alpha'] = None
    mcts_args['root_exploration_fraction'] = None

    search = MCTS(self.game, model, mcts_args)
    root = search.run(model, canonical_state, to_play=1)
    return root.select_action(temperature=0)
||||
def _play_arena_game(self, x_model, o_model):
|
||||
state = self.game.get_init_board()
|
||||
current_player = 1
|
||||
|
||||
while True:
|
||||
reward = self.game.get_reward_for_player(state, current_player)
|
||||
if reward is not None:
|
||||
if reward == 0:
|
||||
return 0
|
||||
return current_player if reward == 1 else -current_player
|
||||
|
||||
model = x_model if current_player == 1 else o_model
|
||||
action = self._ai_action(model, state, current_player)
|
||||
state, current_player = self.game.get_next_state(state, current_player, action)
|
||||
|
||||
def evaluate_candidate(self, candidate_model, reference_model):
|
||||
games = self.args.get('arena_compare_games', 0)
|
||||
if games <= 0:
|
||||
return True, 0.5
|
||||
|
||||
candidate_points = 0.0
|
||||
required_points = self.args.get('arena_accept_threshold', 0.55) * games
|
||||
candidate_first_games = (games + 1) // 2
|
||||
candidate_second_games = games // 2
|
||||
games_played = 0
|
||||
|
||||
for _ in range(candidate_first_games):
|
||||
winner = self._play_arena_game(candidate_model, reference_model)
|
||||
if winner == 1:
|
||||
candidate_points += 1.0
|
||||
elif winner == 0:
|
||||
candidate_points += 0.5
|
||||
games_played += 1
|
||||
remaining_games = games - games_played
|
||||
if candidate_points >= required_points:
|
||||
return True, candidate_points / games
|
||||
if candidate_points + remaining_games < required_points:
|
||||
return False, candidate_points / games
|
||||
|
||||
for _ in range(candidate_second_games):
|
||||
winner = self._play_arena_game(reference_model, candidate_model)
|
||||
if winner == -1:
|
||||
candidate_points += 1.0
|
||||
elif winner == 0:
|
||||
candidate_points += 0.5
|
||||
games_played += 1
|
||||
remaining_games = games - games_played
|
||||
if candidate_points >= required_points:
|
||||
return True, candidate_points / games
|
||||
if candidate_points + remaining_games < required_points:
|
||||
return False, candidate_points / games
|
||||
|
||||
score = candidate_points / games
|
||||
return score >= self.args.get('arena_accept_threshold', 0.55), score
|
||||
|
||||
def exceute_episode(self):
    """Play one self-play game and return its training examples.

    NOTE(review): the method name is misspelled ("exceute"); kept as-is
    because learn() calls it by this name.

    Returns a list of (encoded_state, action_probs, value) tuples where
    value is the final reward seen from the perspective of the player to
    move in that position.
    """
    history = []
    player = 1
    state = self.game.get_init_board()
    step = 0

    while True:
        board_data, inner_board = state
        canonical = (
            self.game.get_canonical_board_data(board_data, player),
            inner_board,
        )

        # Fresh search tree for every move; stored on self as in __init__.
        self.mcts = MCTS(self.game, self.model, self.args)
        root = self.mcts.run(self.model, canonical, to_play=1)

        # Normalized visit counts become the policy target.
        visits = np.zeros(self.game.get_action_size(), dtype=np.float32)
        for move, child in root.children.items():
            visits[move] = child.visit_count
        policy = visits / np.sum(visits)

        history.append((self.game.encode_state(canonical), player, policy))

        # Sample moves early in the game, then play greedily.
        temperature = 1 if step < self.args.get('temperature_threshold', 10) else 0
        move = root.select_action(temperature=temperature)
        state, player = self.game.get_next_state(state, player, move)
        reward = self.game.get_reward_for_player(state, player)
        step += 1

        if reward is not None:
            # Flip the sign for positions where the mover differed from
            # the player the terminal reward is expressed for.
            return [
                (enc, probs, reward * ((-1) ** (mover != player)))
                for enc, mover, probs in history
            ]
||||
def learn(self):
    """Main training loop: self-play -> gradient updates -> arena gating.

    For each of 'numIters' iterations, generates 'numEps' self-play
    episodes, trains on a shuffled copy of the whole replay buffer, then
    keeps the new weights only if they pass evaluate_candidate();
    otherwise model and optimizer are rolled back to the pre-training
    snapshot. Progress and losses are shown via the module-level
    progress-bar state.
    """
    widgets = [Percentage(), Bar(), AdaptiveETA(), LossWidget()]
    pbar = ProgressBar(max_value=self.args['numIters'], widgets=widgets)
    pbar.update(0)

    for i in range(1, self.args['numIters'] + 1):
        # Self-play: collect fresh experience for this iteration.
        train_examples = []
        for eps in range(self.args['numEps']):
            train_examples.extend(self.exceute_episode())

        self.replay_buffer.extend(train_examples)
        training_examples = list(self.replay_buffer)
        shuffle(training_examples)
        # (Removed a dead `shuffle(train_examples)`: that list is unused
        # after being merged into the replay buffer.)

        # Snapshot so a rejected candidate can be rolled back exactly.
        reference_model = copy.deepcopy(self.model)
        reference_model.eval()
        reference_state_dict = copy.deepcopy(self.model.state_dict())
        reference_optimizer_state = copy.deepcopy(self.optimizer.state_dict())

        results = self.train(training_examples)

        accepted, arena_score = self.evaluate_candidate(self.model, reference_model)
        if accepted:
            self.save_checkpoint(folder=".", filename=self.args['checkpoint_path'])
        else:
            # Candidate lost the arena: restore weights and optimizer state.
            self.model.load_state_dict(reference_state_dict)
            self.optimizer.load_state_dict(reference_optimizer_state)

        # Publish the latest losses to the progress-bar widget.
        loss['ploss'] = float(results[0])
        loss['pkl'] = float(results[1])
        loss['vloss'] = float(results[2])
        loss['current'] = i
        loss['max'] = self.args['numIters']
        pbar.update(i)

    pbar.finish()
||||
def train(self, examples):
|
||||
pi_losses = []
|
||||
pi_kls = []
|
||||
v_losses = []
|
||||
device = self.model.device
|
||||
|
||||
for epoch in range(self.args['epochs']):
|
||||
self.model.train()
|
||||
shuffled_indices = np.random.permutation(len(examples))
|
||||
|
||||
for batch_start in range(0, len(examples), self.args['batch_size']):
|
||||
sample_ids = shuffled_indices[batch_start:batch_start + self.args['batch_size']]
|
||||
boards, pis, vs = list(zip(*[examples[i] for i in sample_ids]))
|
||||
boards = torch.FloatTensor(np.array(boards).astype(np.float32))
|
||||
target_pis = torch.FloatTensor(np.array(pis))
|
||||
target_vs = torch.FloatTensor(np.array(vs).astype(np.float32))
|
||||
|
||||
# predict
|
||||
boards = boards.contiguous().to(device)
|
||||
target_pis = target_pis.contiguous().to(device)
|
||||
target_vs = target_vs.contiguous().to(device)
|
||||
|
||||
# compute output
|
||||
out_pi, out_v = self.model(boards)
|
||||
l_pi = self.loss_pi(target_pis, out_pi)
|
||||
l_pi_kl = self.loss_pi_kl(target_pis, out_pi)
|
||||
l_v = self.loss_v(target_vs, out_v)
|
||||
total_loss = l_pi_kl + self.args.get('value_loss_weight', 1.0) * l_v
|
||||
|
||||
pi_losses.append(float(l_pi.detach()))
|
||||
pi_kls.append(float(l_pi_kl.detach()))
|
||||
v_losses.append(float(l_v.detach()))
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
total_loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.get('grad_clip_norm', 5.0))
|
||||
self.optimizer.step()
|
||||
|
||||
# print()
|
||||
# print("Policy Loss", np.mean(pi_losses))
|
||||
# print("Value Loss", np.mean(v_losses))
|
||||
|
||||
return (np.mean(pi_losses), np.mean(pi_kls), np.mean(v_losses))
|
||||
# print("Examples:")
|
||||
# print(out_pi[0].detach())
|
||||
# print(target_pis[0])
|
||||
|
||||
def loss_pi(self, targets, outputs):
    """Mean cross-entropy between target distributions and predictions.

    outputs are probabilities (not logits); they are clamped below at
    1e-8 so log(0) cannot occur.
    """
    log_probs = torch.log(outputs.clamp_min(1e-8))
    per_sample_ce = -(targets * log_probs).sum(dim=1)
    return per_sample_ce.mean()
||||
def loss_pi_kl(self, targets, outputs):
    """Mean KL divergence KL(targets || outputs) over the batch.

    Both arguments are probability distributions; entries are clamped
    at 1e-8 before the log so zero probabilities stay finite.
    """
    safe_targets = targets.clamp_min(1e-8)
    safe_outputs = outputs.clamp_min(1e-8)
    log_ratio = torch.log(safe_targets) - torch.log(safe_outputs)
    return (targets * log_ratio).sum(dim=1).mean()
||||
def loss_v(self, targets, outputs):
    """Mean squared error between value targets and predictions.

    outputs may be shaped (batch, 1); it is flattened before comparison.
    """
    predictions = outputs.view(-1)
    squared_error = (targets - predictions) ** 2
    return squared_error.sum() / targets.size(0)
||||
def save_checkpoint(self, folder, filename):
|
||||
if not os.path.exists(folder):
|
||||
os.mkdir(folder)
|
||||
|
||||
filepath = os.path.join(folder, filename)
|
||||
torch.save({
|
||||
'state_dict': self.model.state_dict(),
|
||||
'optimizer_state_dict': self.optimizer.state_dict(),
|
||||
'args': self.args,
|
||||
}, filepath)
|
||||
Reference in New Issue
Block a user