197 lines
6.7 KiB
Python
197 lines
6.7 KiB
Python
import numpy as np
|
|
|
|
# Index triples that form a straight line on a row-major 3x3 grid.
# The same table is used for cells inside one small board and for the
# small boards themselves on the 3x3 meta-board.
WIN_PATTERNS = [
    # Rows
    (0, 1, 2),
    (3, 4, 5),
    (6, 7, 8),
    # Columns
    (0, 3, 6),
    (1, 4, 7),
    (2, 5, 8),
    # Diagonals
    (0, 4, 8),
    (2, 4, 6),
]
|
|
|
|
class UltimateTicTacToe:
    """
    Rules engine for Ultimate Tic-Tac-Toe: a 3x3 grid of 3x3 small boards.

    A game state ("board") is a tuple ``(board_data, active_board)``:

    * ``board_data`` is a flat array of 81 cells laid out row-major over the
      9x9 grid. A cell holds 1 or -1 for the two players and 0 when empty.
    * ``active_board`` is the index (0-8, row-major) of the small board the
      next move must be played in, or ``None`` when any playable small board
      may be chosen.

    A small board is closed (no longer playable) once either player has three
    in a row on it or it has no empty cell left.
    """

    def __init__(self):
        self.cells = 81  # total number of cells; also the action-space size
        self.board_width = 9  # side length of the full grid
        self.state_planes = 9  # number of planes stacked by encode_state

    def get_init_board(self):
        """Return the starting state: an empty grid with no forced small board."""
        empty_grid = np.zeros((self.cells,), dtype=int)
        return (empty_grid, None)

    def get_board_size(self):
        """Return the encoded state shape as ``(planes, height, width)``."""
        return (self.state_planes, self.board_width, self.board_width)

    def get_action_size(self):
        """Return the number of actions, one per cell."""
        return self.cells

    def get_next_state(self, board, player, action, verify_move=False):
        """
        Apply ``action`` (a flat cell index) for ``player`` and return the result.

        Returns ``((new_board_data, next_board), -player)``. The input board is
        not modified. ``next_board`` is ``None`` when the small board the
        opponent would be sent to is already closed.

        When ``verify_move`` is true, ``False`` is returned instead if the
        action is outside ``0 .. cells - 1`` or is not a currently valid move.
        """
        if verify_move:
            # An out-of-range index is never a legal move; report it the same
            # way as any other illegal move rather than raising IndexError
            # (or silently wrapping around for negative values).
            if not 0 <= action < self.cells:
                return False
            if self.get_valid_moves(board)[action] == 0:
                return False

        new_board_data = np.copy(board[0])
        new_board_data[action] = player

        # The cell's (row, col) position inside its own small board selects
        # the small board the next move has to be played in.
        next_board = ((action // 9) % 3) * 3 + (action % 3)
        if self.is_board_full(new_board_data, next_board):
            next_board = None

        # Hand the turn to the other player by negating the player marker.
        return ((new_board_data, next_board), -player)

    def is_board_full(self, board_data, next_board):
        """
        Return True when small board ``next_board`` is closed.

        Despite the name, a small board counts as "full" here when either
        player has won it or when it has no empty cell left.
        """
        return (
            self._is_small_board_win(board_data, next_board, 1)
            or self._is_small_board_win(board_data, next_board, -1)
            or self._is_board_full(board_data, next_board)
        )

    def _small_board_cells(self, inner_board_idx):
        """Return the nine flat cell indices of a small board, row by row."""
        row_block = inner_board_idx // 3
        col_block = inner_board_idx % 3

        # Top-left cell: each block row spans 3 grid rows of 9 cells (27),
        # each block column spans 3 cells.
        base = row_block * 27 + col_block * 3

        return [base + 9 * row + col for row in range(3) for col in range(3)]

    def _is_board_full(self, board_data, next_board):
        """Return True when small board ``next_board`` has no empty cell."""
        return all(
            board_data[cell] != 0 for cell in self._small_board_cells(next_board)
        )

    def _is_playable_small_board(self, board_data, inner_board_idx):
        """Return True when the small board is neither won nor literally full."""
        return not self.is_board_full(board_data, inner_board_idx)

    def has_legal_moves(self, board):
        """Return True when at least one valid move exists for ``board``."""
        return 1 in self.get_valid_moves(board)

    def get_valid_moves(self, board):
        """
        Return a list of ``cells`` ints: 1 for a playable cell, 0 otherwise.

        If the state names an active small board that is still playable, only
        its empty cells are valid. Otherwise the empty cells of every playable
        small board are valid.
        """
        board_data, active_board = board
        # All moves are invalid by default.
        valid_moves = [0] * self.get_action_size()

        if active_board is not None and self._is_playable_small_board(
            board_data, active_board
        ):
            candidate_boards = [active_board]
        else:
            # No forced board, or the forced board is closed: free choice.
            candidate_boards = [
                inner_board_idx
                for inner_board_idx in range(9)
                if self._is_playable_small_board(board_data, inner_board_idx)
            ]

        for inner_board_idx in candidate_boards:
            for index in self._small_board_cells(inner_board_idx):
                if board_data[index] == 0:
                    valid_moves[index] = 1

        return valid_moves

    def _is_small_board_win(self, board_data, inner_board_idx, player):
        """Return True when ``player`` holds a full line on the given small board."""
        cells = self._small_board_cells(inner_board_idx)
        return any(
            board_data[cells[a]] == board_data[cells[b]] == board_data[cells[c]] == player
            for a, b, c in WIN_PATTERNS
        )

    def is_win(self, board, player):
        """Return True when ``player`` has won three small boards in a line."""
        board_data, _ = board
        won = [self._is_small_board_win(board_data, i, player) for i in range(9)]

        # Same line patterns, applied to the 3x3 grid of small boards.
        return any(won[a] and won[b] and won[c] for a, b, c in WIN_PATTERNS)

    def get_reward_for_player(self, board, player):
        """
        Return the game outcome from ``player``'s point of view.

        Returns 1 if ``player`` has won, -1 if the opponent (``-player``) has
        won, ``None`` if neither has won and moves remain, and 0 otherwise
        (no winner and no valid move left).
        """
        if self.is_win(board, player):
            return 1
        if self.is_win(board, -player):
            return -1
        if self.has_legal_moves(board):
            return None

        return 0

    def get_canonical_board_data(self, board_data, player):
        """Return the cell data multiplied by ``player`` (sign flip for -1)."""
        return player * board_data

    def _small_board_mask(self, inner_board_idx):
        """Return a 9x9 float32 array with 1.0 on the cells of one small board."""
        mask = np.zeros((self.board_width, self.board_width), dtype=np.float32)
        for index in self._small_board_cells(inner_board_idx):
            row, col = divmod(index, self.board_width)
            mask[row, col] = 1.0
        return mask

    def encode_state(self, board):
        """
        Encode a state as a float32 array of shape ``(9, 9, 9)``.

        Planes, in order:

        0. cells equal to 1
        1. cells equal to -1
        2. empty cells
        3. valid moves
        4. the active small board (all zeros when there is none or it is closed)
        5. small boards won by 1
        6. small boards won by -1
        7. small boards that are still playable
        8. fraction of non-empty cells, broadcast over the whole plane

        NOTE(review): the local names suggest the board is expected in
        canonical form (side to move stored as 1) - confirm against callers.
        """
        board_data, active_board = board
        grid_shape = (self.board_width, self.board_width)
        board_grid = board_data.reshape(grid_shape)

        current_stones = (board_grid == 1).astype(np.float32)
        opponent_stones = (board_grid == -1).astype(np.float32)
        empty_cells = (board_grid == 0).astype(np.float32)
        legal_moves = np.array(
            self.get_valid_moves(board), dtype=np.float32
        ).reshape(grid_shape)

        active_board_mask = np.zeros(grid_shape, dtype=np.float32)
        if active_board is not None and self._is_playable_small_board(
            board_data, active_board
        ):
            active_board_mask = self._small_board_mask(active_board)

        current_won_boards = np.zeros(grid_shape, dtype=np.float32)
        opponent_won_boards = np.zeros(grid_shape, dtype=np.float32)
        playable_boards = np.zeros(grid_shape, dtype=np.float32)

        for inner_board_idx in range(9):
            board_mask = self._small_board_mask(inner_board_idx)
            if self._is_small_board_win(board_data, inner_board_idx, 1):
                current_won_boards += board_mask
            elif self._is_small_board_win(board_data, inner_board_idx, -1):
                opponent_won_boards += board_mask

            if self._is_playable_small_board(board_data, inner_board_idx):
                playable_boards += board_mask

        # Game progress in [0, 1], repeated in every cell of its plane.
        move_count = np.count_nonzero(board_data) / self.cells
        move_count_plane = np.full(grid_shape, move_count, dtype=np.float32)

        return np.stack(
            (
                current_stones,
                opponent_stones,
                empty_cells,
                legal_moves,
                active_board_mask,
                current_won_boards,
                opponent_won_boards,
                playable_boards,
                move_count_plane,
            ),
            axis=0,
        )
|