import chess
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Board encoding: 18 planes of 8x8 squares (12 piece planes + 6 state planes),
# flattened to a single feature vector of 18 * 8 * 8 = 1152 values.
NUM_PLANES = 18
BOARD_SIZE = 8
INPUT_SIZE = NUM_PLANES * BOARD_SIZE * BOARD_SIZE  # 1152

# Number of policy logits. 4672 == 8 * 8 * 73, which matches the AlphaZero
# move-plane layout; presumably that is the move indexing used by callers —
# NOTE(review): confirm against the move-encoding code.
POLICY_SIZE = 4672


class O2Net(nn.Module):
    """Fully connected residual network with a policy head and a value head.

    The network consumes a batch of flattened board encodings (as produced by
    ``board_to_tensor``). For each position it returns raw (unnormalized)
    policy logits and a scalar value squashed into [-1, 1] by ``tanh``.

    Args:
        num_res_blocks: Number of residual blocks in the trunk (default 10).
        hidden_size: Width of the input projection and residual trunk
            (default 1024).
    """

    def __init__(self, num_res_blocks: int = 10, hidden_size: int = 1024):
        super().__init__()
        # Input projection from the 8x8x18 encoding into the trunk width.
        self.input_fc = nn.Linear(INPUT_SIZE, hidden_size)

        # Residual trunk. Attribute names and Sequential indices are kept
        # stable so saved state_dicts remain loadable.
        self.res_blocks = nn.ModuleList(
            [self._make_res_block(hidden_size) for _ in range(num_res_blocks)]
        )
        self.res_relu = nn.ReLU()

        # Policy head: trunk -> 512 -> 256 -> POLICY_SIZE logits.
        self.policy_fc1 = nn.Linear(hidden_size, 512)
        self.policy_fc2 = nn.Linear(512, 256)
        self.policy_fc3 = nn.Linear(256, POLICY_SIZE)

        # Value head: trunk -> 512 -> 128 -> 1 scalar.
        self.value_fc1 = nn.Linear(hidden_size, 512)
        self.value_fc2 = nn.Linear(512, 128)
        self.value_fc3 = nn.Linear(128, 1)

    @staticmethod
    def _make_res_block(width: int) -> nn.Sequential:
        """Build one Linear-BN-ReLU-Linear-BN residual branch of the given width."""
        return nn.Sequential(
            nn.Linear(width, width),
            nn.BatchNorm1d(width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.BatchNorm1d(width),
        )

    def forward(self, x: torch.Tensor):
        """Run the network on a batch of encoded boards.

        Args:
            x: Float tensor of shape (batch, 1152). Because of the BatchNorm1d
                layers, a batch of size 1 only works in eval mode
                (``net.eval()``).

        Returns:
            A ``(policy, value)`` tuple: ``policy`` holds raw logits of shape
            (batch, POLICY_SIZE) with no softmax applied; ``value`` has shape
            (batch, 1) with values in [-1, 1].
        """
        x = F.relu(self.input_fc(x))
        for block in self.res_blocks:
            # Skip connection: add the block's input back before the ReLU.
            x = self.res_relu(block(x) + x)

        # Policy head
        p = F.relu(self.policy_fc1(x))
        p = F.relu(self.policy_fc2(p))
        policy = self.policy_fc3(p)

        # Value head
        v = F.relu(self.value_fc1(x))
        v = F.relu(self.value_fc2(v))
        value = torch.tanh(self.value_fc3(v))

        return policy, value


def board_to_tensor(board: chess.Board) -> np.ndarray:
    """Encode a python-chess board as a flat float32 vector of length 1152.

    The encoding is 18 planes of 8x8 (plane-major, then rank, then file),
    flattened:

    * planes 0-5: white pawn, knight, bishop, rook, queen, king;
    * planes 6-11: the same piece types for black;
    * plane 12: all ones if white is to move, else all zeros;
    * planes 13-16: white kingside, white queenside, black kingside and
      black queenside castling rights (each plane filled with 0 or 1);
    * plane 17: a single 1 on ``board.ep_square`` when it is set.

    Squares are laid out from white's fixed point of view (no flipping for the
    side to move); only plane 12 tells the network whose turn it is.
    """
    planes = np.zeros((NUM_PLANES, BOARD_SIZE, BOARD_SIZE), dtype=np.float32)

    for square, piece in board.piece_map().items():
        # piece_type runs PAWN=1 .. KING=6; black pieces are offset by 6 planes.
        plane = (piece.piece_type - 1) + (0 if piece.color == chess.WHITE else 6)
        # python-chess numbers squares a1=0 .. h8=63, so divmod gives (rank, file).
        row, col = divmod(square, 8)
        planes[plane, row, col] = 1

    # Side to move (chess.WHITE is True).
    planes[12, :, :] = int(board.turn)

    # Castling rights, one constant plane each.
    planes[13, :, :] = int(board.has_kingside_castling_rights(chess.WHITE))
    planes[14, :, :] = int(board.has_queenside_castling_rights(chess.WHITE))
    planes[15, :, :] = int(board.has_kingside_castling_rights(chess.BLACK))
    planes[16, :, :] = int(board.has_queenside_castling_rights(chess.BLACK))

    # En passant target square, if any.
    if board.ep_square is not None:
        row, col = divmod(board.ep_square, 8)
        planes[17, row, col] = 1

    return planes.flatten()


if __name__ == "__main__":
    board = chess.Board()
    net = O2Net()
    # A freshly built module is in training mode, where BatchNorm1d rejects a
    # batch of one sample. Switch to inference mode and skip autograd for this
    # single forward pass.
    net.eval()
    x = torch.from_numpy(board_to_tensor(board)).unsqueeze(0)
    with torch.no_grad():
        policy, value = net(x)
    print("Policy shape:", policy.shape)
    print("Value:", value.item())