import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import chess
import random
import math
class SEBlock(nn.Module):
    """Squeeze-and-Excitation block: reweights channels by a learned gate."""
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.fc1 = nn.Linear(channels, channels // reduction)
        self.fc2 = nn.Linear(channels // reduction, channels)

    def forward(self, x):
        b, c, h, w = x.size()
        y = x.view(b, c, -1).mean(dim=2)  # squeeze: global average pool per channel
        y = F.relu(self.fc1(y))
        y = torch.sigmoid(self.fc2(y))    # excitation: per-channel gate in (0, 1)
        y = y.view(b, c, 1, 1)
        return x * y
class ResidualBlock(nn.Module):
    """3x3 conv residual block with squeeze-excitation and 2D dropout."""
    def __init__(self, channels, dropout=0.2):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.se = SEBlock(channels)
        self.dropout = nn.Dropout2d(dropout)

    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.se(out)
        out = self.dropout(out)
        out += residual
        return F.relu(out)
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=64):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        x = x + self.pe[:, :x.size(1), :]
        return x
class ChessNet(nn.Module):
    """Conv/residual trunk, a small transformer over the 64 squares,
    then separate policy and value heads."""
    def __init__(self, input_channels=20, board_size=8, policy_size=64 * 64 * 5,
                 num_blocks=20, transformer_layers=2, nhead=8):
        super().__init__()
        self.board_size = board_size
        self.conv_in = nn.Conv2d(input_channels, 256, kernel_size=3, padding=1)
        self.bn_in = nn.BatchNorm2d(256)
        self.res_blocks = nn.Sequential(*[ResidualBlock(256) for _ in range(num_blocks)])
        # Transformer encoder over the 64 square embeddings
        self.pos_encoder = PositionalEncoding(256, max_len=board_size * board_size)
        encoder_layer = nn.TransformerEncoderLayer(d_model=256, nhead=nhead, dim_feedforward=512, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers)
        self.fc1 = nn.Linear(256 * board_size * board_size, 512)
        self.ln_fc1 = nn.LayerNorm(512)
        # Policy head: one logit per (from_square, to_square, promotion) index,
        # matching Agent.encode_move's flat 64 * 64 * 5 = 20480 scheme
        self.policy_head1 = nn.Linear(512, 256)
        self.policy_head2 = nn.Linear(256, policy_size)
        # Value head
        self.value_head1 = nn.Linear(512, 128)
        self.value_head2 = nn.Linear(128, 1)
    def forward(self, x):
        x = F.relu(self.bn_in(self.conv_in(x)))  # (B, 256, 8, 8)
        x = self.res_blocks(x)                   # (B, 256, 8, 8)
        B, C, H, W = x.shape
        x = x.view(B, C, H * W).permute(0, 2, 1)  # (B, 64, 256)
        x = self.pos_encoder(x)                   # (B, 64, 256)
        x = self.transformer(x)                   # (B, 64, 256)
        x = x.permute(0, 2, 1).contiguous().view(B, -1)  # (B, 256*64)
        x = F.relu(self.ln_fc1(self.fc1(x)))
        # Policy head
        policy = F.relu(self.policy_head1(x))
        policy = self.policy_head2(policy)
        # Value head
        value = F.relu(self.value_head1(x))
        value = torch.tanh(self.value_head2(value))
        return policy, value
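
# Shape walkthrough for the defaults above (illustrative): input (B, 20, 8, 8)
# -> trunk (B, 256, 8, 8) -> 64 square tokens of width 256 through the
# transformer -> flattened to (B, 16384) -> policy logits (B, 20480) and a
# scalar value in [-1, 1] per position.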
class Agent:
    def __init__(self, device='cpu'):
        self.device = device
        self.model = ChessNet().to(device)
        self.model.eval()
    def board_to_tensor(self, board):
        """Encode a board as 20 planes of 8x8:
        0-11 piece type/color (binary), 12-15 castling rights,
        16 move count, 17 en passant square, 18 repetitions, 19 50-move clock."""
        tensor = np.zeros((20, 8, 8), dtype=np.float32)
        # Piece planes (12 binary planes)
        for square, piece in board.piece_map().items():
            row, col = divmod(square, 8)
            tensor[self.piece_to_index(piece), row, col] = 1
        # Castling rights (4 constant planes)
        if board.has_kingside_castling_rights(chess.WHITE):
            tensor[12, :, :] = 1
        if board.has_queenside_castling_rights(chess.WHITE):
            tensor[13, :, :] = 1
        if board.has_kingside_castling_rights(chess.BLACK):
            tensor[14, :, :] = 1
        if board.has_queenside_castling_rights(chess.BLACK):
            tensor[15, :, :] = 1
        # Move count, normalized (1 plane)
        tensor[16, :, :] = board.fullmove_number / 100.0
        # En passant square (1 binary plane, all zeros when unavailable)
        if board.ep_square is not None:
            row, col = divmod(board.ep_square, 8)
            tensor[17, row, col] = 1
        # Repetition flags (1 plane): twofold and threefold, normalized by 2
        rep_count = board.is_repetition(2) + board.is_repetition(3)
        tensor[18, :, :] = rep_count / 2.0
        # 50-move rule counter (1 plane, normalized)
        tensor[19, :, :] = board.halfmove_clock / 100.0
        return torch.tensor(tensor, device=self.device).unsqueeze(0)
    def piece_to_index(self, piece):
        # 0-5: white P, N, B, R, Q, K; 6-11: black P, N, B, R, Q, K
        offset = 0 if piece.color == chess.WHITE else 6
        piece_type_map = {
            chess.PAWN: 0,
            chess.KNIGHT: 1,
            chess.BISHOP: 2,
            chess.ROOK: 3,
            chess.QUEEN: 4,
            chess.KING: 5,
        }
        return offset + piece_type_map[piece.piece_type]
    def predict(self, board):
        x = self.board_to_tensor(board)
        with torch.no_grad():
            policy_logits, value = self.model(x)
        return policy_logits, value
    def diffusion_sample(self, policy_logits, steps=10, noise_scale=1.0, schedule_type='linear'):
        """
        Backward (denoising) diffusion over the policy logits:
        start from pure noise and iteratively denoise toward policy_logits,
        injecting a decaying amount of fresh noise at each step.
        Supports 'linear' and 'cosine' schedules; both decay the noise
        weight alpha toward 0, so the final step returns policy_logits.
        """
        orig = policy_logits.clone()
        x = torch.randn_like(orig) * noise_scale
        if schedule_type == 'cosine':
            # Cosine schedule: alpha decreases from cos(pi/(2*steps)) to cos(pi/2) = 0
            alphas = [np.cos((i / steps) * np.pi / 2) for i in range(1, steps + 1)]
        else:
            # Linear schedule: alpha decreases from (steps-1)/steps to 0
            alphas = np.linspace(1.0, 0.0, steps + 1)[1:]
        for alpha in alphas:
            # Denoising: weighted average between the noisy x and the target logits
            x = alpha * x + (1 - alpha) * orig
            # Decaying noise for stochasticity; vanishes as alpha -> 0
            x = x + torch.randn_like(x) * (noise_scale * (alpha ** 2) / 2)
        return x
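
    # Worked example of the schedules above (illustrative, steps=4): the linear
    # schedule gives alphas = [0.75, 0.5, 0.25, 0.0] and the cosine schedule
    # gives roughly [0.924, 0.707, 0.383, 0.0]; in both cases the blend
    # x = alpha * x + (1 - alpha) * orig lands exactly on the raw logits at the
    # final step, with the injected noise (proportional to alpha^2) already
    # decayed to zero.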
    def predict_with_diffusion(self, board, steps=10, noise_scale=1.0, schedule_type='linear'):
        policy_logits, value = self.predict(board)
        diffused_logits = self.diffusion_sample(policy_logits, steps=steps, noise_scale=noise_scale, schedule_type=schedule_type)
        return diffused_logits, value
    def encode_move(self, move):
        """Encode a chess.Move as a flat integer index in [0, 64*64*5).

        Index = from_square * 64 * 5 + to_square * 5 + promotion, where
        promotion is 0 for none and 1-4 for knight/bishop/rook/queen.
        """
        promo = 0
        if move.promotion:
            promo = {chess.KNIGHT: 1, chess.BISHOP: 2, chess.ROOK: 3, chess.QUEEN: 4}[move.promotion]
        return move.from_square * 64 * 5 + move.to_square * 5 + promo
    def decode_move(self, idx, board):
        """Decode an integer index to a legal chess.Move for the given board."""
        from_sq = idx // (64 * 5)
        to_sq = (idx // 5) % 64
        promo = idx % 5
        promotion = None
        if promo:
            promotion = [None, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN][promo]
        move = chess.Move(from_sq, to_sq, promotion=promotion)
        if move in board.legal_moves:
            return move
        # Fall back to a random legal move if the decoded move is illegal
        return random.choice(list(board.legal_moves))
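
# Minimal usage sketch (illustrative, not part of the original module): run the
# untrained network on the starting position, mask the diffused logits down to
# the legal moves via encode_move, and decode the argmax back into a move.
if __name__ == "__main__":
    agent = Agent(device='cpu')
    board = chess.Board()
    diffused_logits, value = agent.predict_with_diffusion(board, steps=10, schedule_type='cosine')
    # Restrict the policy to legal moves before taking the argmax
    legal_moves = list(board.legal_moves)
    legal_indices = [agent.encode_move(m) for m in legal_moves]
    best = legal_moves[int(torch.argmax(diffused_logits[0, legal_indices]))]
    print(f"value={value.item():+.3f} move={best.uci()}")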