import os
import math
import copy

import numpy as np
from collections import deque
from random import shuffle

from progressbar import ProgressBar, Percentage, Bar, ETA, AdaptiveETA, FormatLabel

import torch
import torch.optim as optim

from mcts import MCTS


# Module-level mutable state shared between Trainer.learn() and the progress-bar
# widget below.  The widget cannot receive per-iteration arguments from the
# progressbar library, so the trainer publishes the latest losses here.
loss = {'ploss': float('inf'), 'pkl': float('inf'), 'vloss': float('inf'), 'current': 0, 'max': 0}


class LossWidget:
    """Custom progressbar widget that renders the most recent training losses."""

    def __call__(self, progress, data=None):
        return (
            f" {loss['current']}/{loss['max']} "
            f"P.CE: {loss['ploss']:.4f}, P.KL: {loss['pkl']:.4f}, V.Loss: {loss['vloss']:.4f}"
        )


class Trainer:
    """AlphaZero-style trainer: self-play data generation, network training,
    and arena-based gating of candidate models against the previous best.

    Args:
        game: game adapter exposing board/state transitions, rewards, encoding.
        model: policy/value network (must expose ``.device`` and be callable
            as ``model(boards) -> (pi, v)``).
        args: dict of hyper-parameters (``numIters``, ``numEps``, ``epochs``,
            ``batch_size``, ``num_simulations``, ``checkpoint_path``, ...).
    """

    def __init__(self, game, model, args):
        self.game = game
        self.model = model
        self.args = args
        self.mcts = MCTS(self.game, self.model, self.args)
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=self.args.get('lr', 5e-4),
            weight_decay=self.args.get('weight_decay', 1e-4),
        )
        # Sliding window of (encoded_state, action_probs, value) training
        # examples pooled across iterations; old examples fall off the left.
        self.replay_buffer = deque(maxlen=self.args.get('replay_buffer_size', 50000))

    def _ai_action(self, model, state, player):
        """Pick a deterministic (temperature=0) move for *player* using a fresh
        MCTS with exploration noise disabled — used for arena evaluation only."""
        board_data, active_board = state
        canonical_state = (self.game.get_canonical_board_data(board_data, player), active_board)
        mcts_args = dict(self.args)
        mcts_args.update(
            {
                'num_simulations': self.args.get('arena_compare_simulations', self.args['num_simulations']),
                # Disable root Dirichlet noise: arena play should be greedy.
                'root_dirichlet_alpha': None,
                'root_exploration_fraction': None,
            }
        )
        root = MCTS(self.game, model, mcts_args).run(model, canonical_state, to_play=1)
        return root.select_action(temperature=0)

    def _play_arena_game(self, x_model, o_model):
        """Play one full game, ``x_model`` as player 1 and ``o_model`` as
        player -1.

        Returns:
            1 if ``x_model`` won, -1 if ``o_model`` won, 0 for a draw.
        """
        state = self.game.get_init_board()
        current_player = 1
        while True:
            reward = self.game.get_reward_for_player(state, current_player)
            if reward is not None:
                if reward == 0:
                    return 0
                # reward is from current_player's perspective; map to absolute winner.
                return current_player if reward == 1 else -current_player
            model = x_model if current_player == 1 else o_model
            action = self._ai_action(model, state, current_player)
            state, current_player = self.game.get_next_state(state, current_player, action)

    def evaluate_candidate(self, candidate_model, reference_model):
        """Gate a candidate network against the reference via arena games.

        Seats are split evenly (candidate moves first in half the games).
        Wins score 1 point, draws 0.5.  Exits early as soon as acceptance or
        rejection is mathematically decided.

        Returns:
            (accepted: bool, score: float) with score in [0, 1].
        """
        games = self.args.get('arena_compare_games', 0)
        if games <= 0:
            # Arena disabled: accept unconditionally.
            return True, 0.5
        candidate_points = 0.0
        required_points = self.args.get('arena_accept_threshold', 0.55) * games
        candidate_first_games = (games + 1) // 2
        candidate_second_games = games // 2
        games_played = 0
        for _ in range(candidate_first_games):
            winner = self._play_arena_game(candidate_model, reference_model)
            if winner == 1:
                candidate_points += 1.0
            elif winner == 0:
                candidate_points += 0.5
            games_played += 1
            remaining_games = games - games_played
            if candidate_points >= required_points:
                return True, candidate_points / games
            # Even winning every remaining game cannot reach the threshold.
            if candidate_points + remaining_games < required_points:
                return False, candidate_points / games
        for _ in range(candidate_second_games):
            winner = self._play_arena_game(reference_model, candidate_model)
            if winner == -1:  # candidate played second (player -1) here
                candidate_points += 1.0
            elif winner == 0:
                candidate_points += 0.5
            games_played += 1
            remaining_games = games - games_played
            if candidate_points >= required_points:
                return True, candidate_points / games
            if candidate_points + remaining_games < required_points:
                return False, candidate_points / games
        score = candidate_points / games
        return score >= self.args.get('arena_accept_threshold', 0.55), score

    def exceute_episode(self):
        """Play one self-play episode and return its training examples.

        NOTE: name kept (with typo) for backward compatibility; prefer the
        ``execute_episode`` alias defined below.

        Returns:
            list of (encoded_state, action_probs, value) tuples, where value
            is the final reward from the perspective of the player to move
            in that state.
        """
        train_examples = []
        current_player = 1
        state = self.game.get_init_board()
        episode_step = 0
        while True:
            board_data, state_inner_board = state
            cannonical_board_data = self.game.get_canonical_board_data(board_data, current_player)
            canonical_board = (cannonical_board_data, state_inner_board)
            # Fresh search tree per move (no tree reuse).
            self.mcts = MCTS(self.game, self.model, self.args)
            root = self.mcts.run(self.model, canonical_board, to_play=1)

            # Visit counts -> normalized policy target.
            action_probs = np.zeros(self.game.get_action_size(), dtype=np.float32)
            for k, v in root.children.items():
                action_probs[k] = v.visit_count
            action_probs = action_probs / np.sum(action_probs)

            encoded_state = self.game.encode_state(canonical_board)
            train_examples.append((encoded_state, current_player, action_probs))

            # Exploratory (T=1) sampling early in the game, greedy afterwards.
            temperature_threshold = self.args.get('temperature_threshold', 10)
            temperature = 1 if episode_step < temperature_threshold else 0
            action = root.select_action(temperature=temperature)
            state, current_player = self.game.get_next_state(state, current_player, action)
            reward = self.game.get_reward_for_player(state, current_player)
            episode_step += 1

            if reward is not None:
                ret = []
                for hist_state, hist_current_player, hist_action_probs in train_examples:
                    # [Board, actionProbabilities, Reward] — reward flipped for
                    # positions where the opponent of the final player moved.
                    ret.append((hist_state, hist_action_probs,
                                reward * ((-1) ** (hist_current_player != current_player))))
                return ret

    # Correctly-spelled alias; keeps the misspelled public name working.
    execute_episode = exceute_episode

    def learn(self):
        """Main loop: self-play -> train -> arena gate -> checkpoint/rollback."""
        widgets = [Percentage(), Bar(), AdaptiveETA(), LossWidget()]
        pbar = ProgressBar(max_value=self.args['numIters'], widgets=widgets)
        pbar.update(0)
        for i in range(1, self.args['numIters'] + 1):
            train_examples = []
            for eps in range(self.args['numEps']):
                iteration_train_examples = self.exceute_episode()
                train_examples.extend(iteration_train_examples)

            self.replay_buffer.extend(train_examples)
            training_examples = list(self.replay_buffer)
            shuffle(training_examples)

            # Snapshot the pre-training model/optimizer so we can roll back
            # if the arena rejects the trained candidate.
            reference_model = copy.deepcopy(self.model)
            reference_model.eval()
            reference_state_dict = copy.deepcopy(self.model.state_dict())
            reference_optimizer_state = copy.deepcopy(self.optimizer.state_dict())

            results = self.train(training_examples)

            accepted, arena_score = self.evaluate_candidate(self.model, reference_model)
            if accepted:
                filename = self.args['checkpoint_path']
                self.save_checkpoint(folder=".", filename=filename)
            else:
                self.model.load_state_dict(reference_state_dict)
                self.optimizer.load_state_dict(reference_optimizer_state)

            # Publish losses for the progress-bar widget.
            loss['ploss'] = float(results[0])
            loss['pkl'] = float(results[1])
            loss['vloss'] = float(results[2])
            loss['current'] = i
            loss['max'] = self.args['numIters']
            pbar.update(i)
        pbar.finish()

    def train(self, examples):
        """Run ``args['epochs']`` epochs of SGD over *examples*.

        Optimizes KL(policy target || policy) + weighted value MSE; the plain
        cross-entropy is tracked for reporting only.

        Returns:
            (mean_policy_ce, mean_policy_kl, mean_value_loss) as floats
            (NaN triple when *examples* is empty).
        """
        if not examples:
            return (float('nan'), float('nan'), float('nan'))
        pi_losses = []
        pi_kls = []
        v_losses = []
        device = self.model.device
        for epoch in range(self.args['epochs']):
            self.model.train()
            shuffled_indices = np.random.permutation(len(examples))
            for batch_start in range(0, len(examples), self.args['batch_size']):
                sample_ids = shuffled_indices[batch_start:batch_start + self.args['batch_size']]
                boards, pis, vs = list(zip(*[examples[i] for i in sample_ids]))
                boards = torch.FloatTensor(np.array(boards).astype(np.float32))
                target_pis = torch.FloatTensor(np.array(pis))
                target_vs = torch.FloatTensor(np.array(vs).astype(np.float32))

                boards = boards.contiguous().to(device)
                target_pis = target_pis.contiguous().to(device)
                target_vs = target_vs.contiguous().to(device)

                # compute output
                out_pi, out_v = self.model(boards)
                l_pi = self.loss_pi(target_pis, out_pi)
                l_pi_kl = self.loss_pi_kl(target_pis, out_pi)
                l_v = self.loss_v(target_vs, out_v)
                total_loss = l_pi_kl + self.args.get('value_loss_weight', 1.0) * l_v

                pi_losses.append(float(l_pi.detach()))
                pi_kls.append(float(l_pi_kl.detach()))
                v_losses.append(float(l_v.detach()))

                self.optimizer.zero_grad()
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.get('grad_clip_norm', 5.0))
                self.optimizer.step()
        return (np.mean(pi_losses), np.mean(pi_kls), np.mean(v_losses))

    def loss_pi(self, targets, outputs):
        """Mean cross-entropy between target distributions and predicted
        probabilities (outputs assumed to already be probabilities)."""
        ce = -(targets * torch.log(outputs.clamp_min(1e-8))).sum(dim=1)
        return ce.mean()

    def loss_pi_kl(self, targets, outputs):
        """Mean KL divergence KL(targets || outputs); clamped logs keep
        zero-probability entries finite."""
        target_log = torch.log(targets.clamp_min(1e-8))
        output_log = torch.log(outputs.clamp_min(1e-8))
        kl = (targets * (target_log - output_log)).sum(dim=1)
        return kl.mean()

    def loss_v(self, targets, outputs):
        """Mean squared error of the value head (outputs flattened to 1-D)."""
        mse = torch.sum((targets - outputs.view(-1)) ** 2) / targets.size()[0]
        return mse

    def save_checkpoint(self, folder, filename):
        """Serialize model + optimizer state and args to ``folder/filename``.

        Uses ``os.makedirs(..., exist_ok=True)`` so nested paths work and a
        concurrent creator cannot cause a race (original ``exists``+``mkdir``
        could raise FileExistsError and failed on nested folders).
        """
        os.makedirs(folder, exist_ok=True)
        filepath = os.path.join(folder, filename)
        torch.save({
            'state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'args': self.args,
        }, filepath)