import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import chess
import random
import math


class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.fc1 = nn.Linear(channels, channels // reduction)
        self.fc2 = nn.Linear(channels // reduction, channels)

    def forward(self, x):
        b, c, h, w = x.size()
        y = x.view(b, c, -1).mean(dim=2)
        y = F.relu(self.fc1(y))
        y = torch.sigmoid(self.fc2(y))
        y = y.view(b, c, 1, 1)
        return x * y


class ResidualBlock(nn.Module):
    def __init__(self, channels, dropout=0.2):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.se = SEBlock(channels)
        self.dropout = nn.Dropout2d(dropout)

    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.se(out)
        out = self.dropout(out)
        out += residual
        return F.relu(out)


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=64):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        x = x + self.pe[:, :x.size(1), :]
        return x


class ChessNet(nn.Module):
    def __init__(self, input_channels=20, board_size=8, policy_size=20480, num_blocks=20, transformer_layers=2, nhead=8):
        # policy_size defaults to 64 * 64 * 5 = 20480 so that the policy head
        # matches the from-square/to-square/promotion encoding in Agent.encode_move.
        super().__init__()
        self.board_size = board_size
        self.conv_in = nn.Conv2d(input_channels, 256, kernel_size=3, padding=1)
        self.bn_in = nn.BatchNorm2d(256)
        self.res_blocks = nn.Sequential(*[ResidualBlock(256) for _ in range(num_blocks)])
        # Transformer encoder
        self.pos_encoder = PositionalEncoding(256, max_len=board_size * board_size)
        encoder_layer = nn.TransformerEncoderLayer(d_model=256, nhead=nhead, dim_feedforward=512, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers)
        self.fc1 = nn.Linear(256 * board_size * board_size, 512)
        self.ln_fc1 = nn.LayerNorm(512)
        # Policy head
        self.policy_head1 = nn.Linear(512, 256)
        self.policy_head2 = nn.Linear(256, policy_size)
        # Value head
        self.value_head1 = nn.Linear(512, 128)
        self.value_head2 = nn.Linear(128, 1)

    def forward(self, x):
        x = F.relu(self.bn_in(self.conv_in(x)))          # (B, 256, 8, 8)
        x = self.res_blocks(x)                           # (B, 256, 8, 8)
        B, C, H, W = x.shape
        x = x.view(B, C, H * W).permute(0, 2, 1)         # (B, 64, 256)
        x = self.pos_encoder(x)                          # (B, 64, 256)
        x = self.transformer(x)                          # (B, 64, 256)
        x = x.permute(0, 2, 1).contiguous().view(B, -1)  # (B, 256*64)
        x = F.relu(self.ln_fc1(self.fc1(x)))
        # Policy head
        policy = F.relu(self.policy_head1(x))
        policy = self.policy_head2(policy)
        # Value head
        value = F.relu(self.value_head1(x))
        value = torch.tanh(self.value_head2(value))
        return policy, value
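

# Illustrative shape check (a sketch added for clarity, not part of the original
# module): with the default input_channels=20, the network maps a (B, 20, 8, 8)
# board tensor to (B, policy_size) policy logits and a (B, 1) value in (-1, 1).
# The helper name _shape_check is hypothetical; call it manually if desired.
def _shape_check(batch_size=1):
    net = ChessNet()
    net.eval()
    with torch.no_grad():
        policy, value = net(torch.zeros(batch_size, 20, 8, 8))
    return policy.shape, value.shape  # (batch_size, 20480), (batch_size, 1)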


class Agent:
    def __init__(self, device='cpu'):
        self.device = device
        self.model = ChessNet().to(device)
        self.model.eval()

    def board_to_tensor(self, board):
        # Planes 0-11: binary piece planes (white P,N,B,R,Q,K, then black P,N,B,R,Q,K).
        # Planes 12-15: castling rights, plane 16: normalized move count,
        # plane 17: en passant square, plane 18: repetition count,
        # plane 19: 50-move-rule counter -> 20 planes total (matches input_channels=20).
        piece_map = board.piece_map()
        tensor = np.zeros((17, 8, 8), dtype=np.float32)
        for square, piece in piece_map.items():
            idx = self.piece_to_index(piece)
            row, col = divmod(square, 8)
            tensor[idx, row, col] = 1
        # Castling rights (4 planes)
        if board.has_kingside_castling_rights(chess.WHITE):
            tensor[12, :, :] = 1
        if board.has_queenside_castling_rights(chess.WHITE):
            tensor[13, :, :] = 1
        if board.has_kingside_castling_rights(chess.BLACK):
            tensor[14, :, :] = 1
        if board.has_queenside_castling_rights(chess.BLACK):
            tensor[15, :, :] = 1
        # Move count (normalized, 1 plane)
        tensor[16, :, :] = board.fullmove_number / 100.0
        # En passant square (1 plane)
        tensor = np.concatenate([tensor, np.zeros((1, 8, 8), dtype=np.float32)], axis=0)
        if board.ep_square is not None:
            row, col = divmod(board.ep_square, 8)
            tensor[-1, row, col] = 1
        # Repetition count (1 plane, normalized); is_repetition returns a bool,
        # so the sum is 0, 1, or 2 depending on two- and threefold repetition.
        rep_count = int(board.is_repetition(2)) + int(board.is_repetition(3))
        tensor = np.concatenate([tensor, np.full((1, 8, 8), rep_count / 3.0, dtype=np.float32)], axis=0)
        # 50-move rule counter (1 plane, normalized)
        tensor = np.concatenate([tensor, np.full((1, 8, 8), board.halfmove_clock / 100.0, dtype=np.float32)], axis=0)
        return torch.tensor(tensor, device=self.device).unsqueeze(0)

    def piece_to_index(self, piece):
        # 0-5: white P, N, B, R, Q, K; 6-11: black P, N, B, R, Q, K
        offset = 0 if piece.color == chess.WHITE else 6
        piece_type_map = {
            chess.PAWN: 0,
            chess.KNIGHT: 1,
            chess.BISHOP: 2,
            chess.ROOK: 3,
            chess.QUEEN: 4,
            chess.KING: 5,
        }
        return offset + piece_type_map[piece.piece_type]

    def predict(self, board):
        x = self.board_to_tensor(board)
        with torch.no_grad():
            policy_logits, value = self.model(x)
        return policy_logits, value

    def diffusion_sample(self, policy_logits, steps=10, noise_scale=1.0, schedule_type='linear'):
        """
        Backward (denoising) diffusion-style sampling over the policy logits.
        - Start from pure noise and iteratively blend toward policy_logits.
        - Supports linear and cosine schedules for the noise weight alpha, which
          decreases from ~1 to 0 so the final sample ends up close to the logits.
        - Adds a small amount of decreasing noise at each step for stochasticity.
        """
        orig = policy_logits.clone()
        x = torch.randn_like(orig) * noise_scale
        if schedule_type == 'cosine':
            # Cosine schedule: alpha decreases from cos(pi / (2 * steps)) ~ 1 down to 0.
            alphas = [math.cos((i / steps) * math.pi / 2) for i in range(1, steps + 1)]
        else:
            # Linear schedule: alpha decreases from just below 1 down to 0.
            alphas = np.linspace(1.0, 0.0, steps + 1)[1:]
        for alpha in alphas:
            # Denoising step: weighted average between the noisy sample and the logits.
            x = alpha * x + (1 - alpha) * orig
            # Add decreasing noise for stochasticity.
            step_noise = torch.randn_like(x) * (noise_scale * (alpha ** 2) / 2)
            x = x + step_noise
        return x

    def predict_with_diffusion(self, board, steps=10, noise_scale=1.0, schedule_type='linear'):
        x = self.board_to_tensor(board)
        with torch.no_grad():
            policy_logits, value = self.model(x)
        diffused_logits = self.diffusion_sample(policy_logits, steps=steps, noise_scale=noise_scale, schedule_type=schedule_type)
        return diffused_logits, value

    def encode_move(self, move):
        """Encode a chess.Move as from_square * 64 * 5 + to_square * 5 + promotion,
        giving indices in [0, 64 * 64 * 5) = [0, 20480), matching the policy head size."""
        # Promotion code 0: none, 1: knight, 2: bishop, 3: rook, 4: queen.
        from_sq = move.from_square
        to_sq = move.to_square
        promo = 0
        if move.promotion:
            promo = {chess.KNIGHT: 1, chess.BISHOP: 2, chess.ROOK: 3, chess.QUEEN: 4}[move.promotion]
        return from_sq * 64 * 5 + to_sq * 5 + promo

    def decode_move(self, idx, board):
        """Decode an integer index back to a chess.Move that is legal for the given board."""
        from_sq = idx // (64 * 5)
        to_sq = (idx // 5) % 64
        promo = idx % 5
        promotion = None
        if promo:
            promotion = [None, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN][promo]
        move = chess.Move(from_sq, to_sq, promotion=promotion)
        if move in board.legal_moves:
            return move
        # If the decoded move is not legal, fall back to a random legal move.
        return random.choice(list(board.legal_moves))
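

# Illustrative usage sketch (added for clarity, not part of the original module):
# score the legal moves of a position by indexing the diffused policy logits with
# encode_move, then play the highest-scoring legal move. The variable names below
# are hypothetical; they only demonstrate the intended API.
if __name__ == "__main__":
    agent = Agent(device='cpu')
    board = chess.Board()
    logits, value = agent.predict_with_diffusion(board, steps=10, schedule_type='cosine')
    logits = logits.squeeze(0)  # (policy_size,)
    legal_moves = list(board.legal_moves)
    scores = [logits[agent.encode_move(m)].item() for m in legal_moves]
    best_move = legal_moves[int(np.argmax(scores))]
    print(f"value estimate: {value.item():+.3f}, chosen move: {best_move.uci()}")
    board.push(best_move)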