First commit

This commit is contained in:
2026-03-23 21:19:29 +01:00
commit 29fc731e6c
7 changed files with 1173 additions and 0 deletions

258
trainer.py Normal file
View File

@@ -0,0 +1,258 @@
import os
import math
import copy
import numpy as np
from collections import deque
from random import shuffle
from progressbar import ProgressBar, Percentage, Bar, ETA, AdaptiveETA, FormatLabel
import torch
import torch.optim as optim
from mcts import MCTS
loss = {'ploss': float('inf'), 'pkl': float('inf'), 'vloss': float('inf'), 'current': 0, 'max': 0}
class LossWidget:
    """Progressbar widget that renders the latest loss stats from the
    module-level ``loss`` dict."""

    def __call__(self, progress, data=None):
        """Return a one-line summary of iteration count and current losses."""
        template = (
            " {current}/{max} "
            "P.CE: {ploss:.4f}, P.KL: {pkl:.4f}, V.Loss: {vloss:.4f}"
        )
        return template.format(**loss)
class Trainer:
    """AlphaZero-style trainer: generates self-play games with MCTS, trains the
    network on a replay buffer, and gates new weights behind an arena match
    against the previous model before checkpointing."""

    def __init__(self, game, model, args):
        self.game = game
        self.model = model
        self.args = args
        self.mcts = MCTS(self.game, self.model, self.args)
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=self.args.get('lr', 5e-4),
            weight_decay=self.args.get('weight_decay', 1e-4),
        )
        # Sliding window of (encoded_state, action_probs, value) examples;
        # old iterations fall off the left once maxlen is reached.
        self.replay_buffer = deque(maxlen=self.args.get('replay_buffer_size', 50000))

    def _ai_action(self, model, state, player):
        """Pick a deterministic (temperature=0) move for ``player`` using a
        fresh MCTS with root exploration noise disabled (arena play)."""
        board_data, active_board = state
        canonical_state = (self.game.get_canonical_board_data(board_data, player), active_board)
        mcts_args = dict(self.args)
        mcts_args.update(
            {
                'num_simulations': self.args.get('arena_compare_simulations', self.args['num_simulations']),
                # Disable Dirichlet noise so arena games measure raw strength.
                'root_dirichlet_alpha': None,
                'root_exploration_fraction': None,
            }
        )
        root = MCTS(self.game, model, mcts_args).run(model, canonical_state, to_play=1)
        return root.select_action(temperature=0)

    def _play_arena_game(self, x_model, o_model):
        """Play one full game between two models.

        Returns 1 if ``x_model`` (first player) wins, -1 if ``o_model`` wins,
        0 on a draw.
        """
        state = self.game.get_init_board()
        current_player = 1
        while True:
            reward = self.game.get_reward_for_player(state, current_player)
            if reward is not None:
                if reward == 0:
                    return 0
                # Reward is from current_player's perspective; map it back to
                # an absolute winner (1 = x_model, -1 = o_model).
                return current_player if reward == 1 else -current_player
            model = x_model if current_player == 1 else o_model
            action = self._ai_action(model, state, current_player)
            state, current_player = self.game.get_next_state(state, current_player, action)

    def evaluate_candidate(self, candidate_model, reference_model):
        """Arena-match the candidate against the reference model.

        Alternates which model moves first across two halves of the schedule.
        Returns ``(accepted, score)`` where ``score`` is the candidate's point
        share (win=1, draw=0.5) over all games. Short-circuits as soon as the
        acceptance threshold is mathematically decided either way.
        """
        games = self.args.get('arena_compare_games', 0)
        if games <= 0:
            # Arena disabled: accept unconditionally with a neutral score.
            return True, 0.5
        candidate_points = 0.0
        required_points = self.args.get('arena_accept_threshold', 0.55) * games
        candidate_first_games = (games + 1) // 2
        candidate_second_games = games // 2
        games_played = 0
        for _ in range(candidate_first_games):
            winner = self._play_arena_game(candidate_model, reference_model)
            if winner == 1:
                candidate_points += 1.0
            elif winner == 0:
                candidate_points += 0.5
            games_played += 1
            remaining_games = games - games_played
            if candidate_points >= required_points:
                return True, candidate_points / games
            if candidate_points + remaining_games < required_points:
                # Even winning every remaining game cannot reach the threshold.
                return False, candidate_points / games
        for _ in range(candidate_second_games):
            winner = self._play_arena_game(reference_model, candidate_model)
            if winner == -1:
                candidate_points += 1.0
            elif winner == 0:
                candidate_points += 0.5
            games_played += 1
            remaining_games = games - games_played
            if candidate_points >= required_points:
                return True, candidate_points / games
            if candidate_points + remaining_games < required_points:
                return False, candidate_points / games
        score = candidate_points / games
        return score >= self.args.get('arena_accept_threshold', 0.55), score

    def execute_episode(self):
        """Self-play one full game and return its training examples.

        Returns a list of ``(encoded_state, action_probs, value)`` tuples,
        where ``value`` is the final reward from the perspective of the player
        to move in that state.
        """
        train_examples = []
        current_player = 1
        state = self.game.get_init_board()
        episode_step = 0
        while True:
            board_data, state_inner_board = state
            canonical_board_data = self.game.get_canonical_board_data(board_data, current_player)
            canonical_board = (canonical_board_data, state_inner_board)
            # Fresh search tree each move (no tree reuse between moves).
            self.mcts = MCTS(self.game, self.model, self.args)
            root = self.mcts.run(self.model, canonical_board, to_play=1)
            # Visit-count distribution over actions is the policy target.
            action_probs = np.zeros(self.game.get_action_size(), dtype=np.float32)
            for k, v in root.children.items():
                action_probs[k] = v.visit_count
            action_probs = action_probs / np.sum(action_probs)
            encoded_state = self.game.encode_state(canonical_board)
            train_examples.append((encoded_state, current_player, action_probs))
            # Sample proportionally to visits early on, then play greedily.
            temperature_threshold = self.args.get('temperature_threshold', 10)
            temperature = 1 if episode_step < temperature_threshold else 0
            action = root.select_action(temperature=temperature)
            state, current_player = self.game.get_next_state(state, current_player, action)
            reward = self.game.get_reward_for_player(state, current_player)
            episode_step += 1
            if reward is not None:
                ret = []
                for hist_state, hist_current_player, hist_action_probs in train_examples:
                    # [Board, actionProbabilities, Reward]: flip the reward's
                    # sign for states where the other player was to move.
                    ret.append((hist_state, hist_action_probs, reward * ((-1) ** (hist_current_player != current_player))))
                return ret

    # Backward-compatible alias for the historical misspelling.
    exceute_episode = execute_episode

    def learn(self):
        """Main training loop over ``numIters`` iterations.

        Each iteration: self-play ``numEps`` episodes, train on the shuffled
        replay buffer, then arena-gate the new weights — checkpoint them if
        accepted, otherwise roll model and optimizer back to the reference.
        """
        widgets = [Percentage(), Bar(), AdaptiveETA(), LossWidget()]
        pbar = ProgressBar(max_value=self.args['numIters'], widgets=widgets)
        pbar.update(0)
        for i in range(1, self.args['numIters'] + 1):
            train_examples = []
            for _ in range(self.args['numEps']):
                iteration_train_examples = self.execute_episode()
                train_examples.extend(iteration_train_examples)
            self.replay_buffer.extend(train_examples)
            training_examples = list(self.replay_buffer)
            shuffle(training_examples)
            # Snapshot the current weights so a rejected candidate can be
            # rolled back, and keep a frozen copy as the arena opponent.
            reference_model = copy.deepcopy(self.model)
            reference_model.eval()
            reference_state_dict = copy.deepcopy(self.model.state_dict())
            reference_optimizer_state = copy.deepcopy(self.optimizer.state_dict())
            results = self.train(training_examples)
            accepted, arena_score = self.evaluate_candidate(self.model, reference_model)
            if accepted:
                filename = self.args['checkpoint_path']
                self.save_checkpoint(folder=".", filename=filename)
            else:
                self.model.load_state_dict(reference_state_dict)
                self.optimizer.load_state_dict(reference_optimizer_state)
            # Publish the latest losses for LossWidget to render.
            loss['ploss'] = float(results[0])
            loss['pkl'] = float(results[1])
            loss['vloss'] = float(results[2])
            loss['current'] = i
            loss['max'] = self.args['numIters']
            pbar.update(i)
        pbar.finish()

    def train(self, examples):
        """Run ``epochs`` passes of minibatch SGD over ``examples``.

        ``examples`` is a list of ``(board, pi, v)`` tuples. Returns the mean
        ``(policy_cross_entropy, policy_kl, value_loss)`` over all batches.
        """
        pi_losses = []
        pi_kls = []
        v_losses = []
        device = self.model.device
        for epoch in range(self.args['epochs']):
            self.model.train()
            shuffled_indices = np.random.permutation(len(examples))
            for batch_start in range(0, len(examples), self.args['batch_size']):
                sample_ids = shuffled_indices[batch_start:batch_start + self.args['batch_size']]
                boards, pis, vs = list(zip(*[examples[i] for i in sample_ids]))
                boards = torch.FloatTensor(np.array(boards).astype(np.float32))
                target_pis = torch.FloatTensor(np.array(pis))
                target_vs = torch.FloatTensor(np.array(vs).astype(np.float32))
                boards = boards.contiguous().to(device)
                target_pis = target_pis.contiguous().to(device)
                target_vs = target_vs.contiguous().to(device)
                # Forward pass: policy distribution and value estimate.
                out_pi, out_v = self.model(boards)
                l_pi = self.loss_pi(target_pis, out_pi)
                l_pi_kl = self.loss_pi_kl(target_pis, out_pi)
                l_v = self.loss_v(target_vs, out_v)
                # Optimize KL + weighted value loss; plain CE is logged only.
                total_loss = l_pi_kl + self.args.get('value_loss_weight', 1.0) * l_v
                pi_losses.append(float(l_pi.detach()))
                pi_kls.append(float(l_pi_kl.detach()))
                v_losses.append(float(l_v.detach()))
                self.optimizer.zero_grad()
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.get('grad_clip_norm', 5.0))
                self.optimizer.step()
        return (np.mean(pi_losses), np.mean(pi_kls), np.mean(v_losses))

    def loss_pi(self, targets, outputs):
        """Cross-entropy between target distribution and predicted
        probabilities (``outputs`` are probabilities, not logits)."""
        ce = -(targets * torch.log(outputs.clamp_min(1e-8))).sum(dim=1)
        return ce.mean()

    def loss_pi_kl(self, targets, outputs):
        """KL divergence KL(targets || outputs); both are probability
        distributions, clamped to avoid log(0)."""
        target_log = torch.log(targets.clamp_min(1e-8))
        output_log = torch.log(outputs.clamp_min(1e-8))
        kl = (targets * (target_log - output_log)).sum(dim=1)
        return kl.mean()

    def loss_v(self, targets, outputs):
        """Mean squared error between target values and predictions
        (``outputs`` is flattened to match ``targets``)."""
        mse = torch.sum((targets - outputs.view(-1)) ** 2) / targets.size()[0]
        return mse

    def save_checkpoint(self, folder, filename):
        """Save model weights, optimizer state and args to folder/filename.

        Uses ``makedirs(exist_ok=True)``: race-free and supports nested paths
        (the old exists-check + ``mkdir`` failed on both counts).
        """
        os.makedirs(folder, exist_ok=True)
        filepath = os.path.join(folder, filename)
        torch.save({
            'state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'args': self.args,
        }, filepath)