# play-with-o1/src/o1/agent.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import chess
import random
import math


class SEBlock(nn.Module):
    """Squeeze-and-excitation block: learns per-channel attention weights."""
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.fc1 = nn.Linear(channels, channels // reduction)
        self.fc2 = nn.Linear(channels // reduction, channels)

    def forward(self, x):
        b, c, h, w = x.size()
        y = x.view(b, c, -1).mean(dim=2)  # squeeze: global average pool -> (B, C)
        y = F.relu(self.fc1(y))
        y = torch.sigmoid(self.fc2(y))    # excitation: per-channel gates in (0, 1)
        y = y.view(b, c, 1, 1)
        return x * y                      # rescale each channel of the input


class ResidualBlock(nn.Module):
    """Residual block with squeeze-and-excitation and 2D dropout."""
    def __init__(self, channels, dropout=0.2):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.se = SEBlock(channels)
        self.dropout = nn.Dropout2d(dropout)

    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.se(out)
        out = self.dropout(out)
        out += residual  # skip connection
        return F.relu(out)


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=64):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        x = x + self.pe[:, :x.size(1), :]
        return x
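

# Shape sanity check for PositionalEncoding (hypothetical usage, not part of
# the original upload):
#
#     enc = PositionalEncoding(d_model=256, max_len=64)
#     enc(torch.zeros(2, 64, 256)).shape  # torch.Size([2, 64, 256])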


class ChessNet(nn.Module):
    """Residual CNN trunk plus a transformer encoder, with policy and value heads."""
    def __init__(self, input_channels=20, board_size=8, policy_size=20480,
                 num_blocks=20, transformer_layers=2, nhead=8):
        # policy_size = 64 from-squares * 64 to-squares * 5 promotion codes,
        # matching Agent.encode_move / decode_move below.
        super().__init__()
        self.board_size = board_size
        self.conv_in = nn.Conv2d(input_channels, 256, kernel_size=3, padding=1)
        self.bn_in = nn.BatchNorm2d(256)
        self.res_blocks = nn.Sequential(*[ResidualBlock(256) for _ in range(num_blocks)])
        # Transformer encoder over the 64 square embeddings
        self.pos_encoder = PositionalEncoding(256, max_len=board_size * board_size)
        encoder_layer = nn.TransformerEncoderLayer(d_model=256, nhead=nhead, dim_feedforward=512, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers)
        self.fc1 = nn.Linear(256 * board_size * board_size, 512)
        self.ln_fc1 = nn.LayerNorm(512)
        # Policy head
        self.policy_head1 = nn.Linear(512, 256)
        self.policy_head2 = nn.Linear(256, policy_size)
        # Value head
        self.value_head1 = nn.Linear(512, 128)
        self.value_head2 = nn.Linear(128, 1)

    def forward(self, x):
        x = F.relu(self.bn_in(self.conv_in(x)))          # (B, 256, 8, 8)
        x = self.res_blocks(x)                           # (B, 256, 8, 8)
        B, C, H, W = x.shape
        x = x.view(B, C, H * W).permute(0, 2, 1)         # (B, 64, 256)
        x = self.pos_encoder(x)                          # (B, 64, 256)
        x = self.transformer(x)                          # (B, 64, 256)
        x = x.permute(0, 2, 1).contiguous().view(B, -1)  # (B, 256*64)
        x = F.relu(self.ln_fc1(self.fc1(x)))
        # Policy head
        policy = F.relu(self.policy_head1(x))
        policy = self.policy_head2(policy)
        # Value head
        value = F.relu(self.value_head1(x))
        value = torch.tanh(self.value_head2(value))
        return policy, value
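

# A quick shape check for ChessNet (hypothetical usage; the 20 input planes
# match Agent.board_to_tensor below):
#
#     net = ChessNet()
#     policy, value = net(torch.zeros(1, 20, 8, 8))
#     policy.shape  # torch.Size([1, 20480])
#     value.shape   # torch.Size([1, 1])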


class Agent:
    """Wraps ChessNet with board encoding, inference, diffusion sampling, and move (de)coding."""
    def __init__(self, device='cpu'):
        self.device = device
        self.model = ChessNet().to(device)
        self.model.eval()  # inference mode by default

    def board_to_tensor(self, board):
        """Encode a board as a (1, 20, 8, 8) tensor of feature planes."""
        tensor = np.zeros((20, 8, 8), dtype=np.float32)
        # Planes 0-11: binary piece planes (white P,N,B,R,Q,K, then black P,N,B,R,Q,K)
        for square, piece in board.piece_map().items():
            idx = self.piece_to_index(piece)
            row, col = divmod(square, 8)
            tensor[idx, row, col] = 1
        # Planes 12-15: castling rights
        if board.has_kingside_castling_rights(chess.WHITE):
            tensor[12, :, :] = 1
        if board.has_queenside_castling_rights(chess.WHITE):
            tensor[13, :, :] = 1
        if board.has_kingside_castling_rights(chess.BLACK):
            tensor[14, :, :] = 1
        if board.has_queenside_castling_rights(chess.BLACK):
            tensor[15, :, :] = 1
        # Plane 16: full-move number, normalized
        tensor[16, :, :] = board.fullmove_number / 100.0
        # Plane 17: en passant square (left all-zero when there is none)
        if board.ep_square is not None:
            row, col = divmod(board.ep_square, 8)
            tensor[17, row, col] = 1
        # Plane 18: repetition count, normalized (a threefold repetition is also a twofold one)
        rep_count = int(board.is_repetition(2)) + int(board.is_repetition(3))
        tensor[18, :, :] = rep_count / 3.0
        # Plane 19: fifty-move-rule counter, normalized
        tensor[19, :, :] = board.halfmove_clock / 100.0
        return torch.tensor(tensor, device=self.device).unsqueeze(0)

    def piece_to_index(self, piece):
        # 0-5: white P,N,B,R,Q,K; 6-11: black P,N,B,R,Q,K
        offset = 0 if piece.color == chess.WHITE else 6
        piece_type_map = {
            chess.PAWN: 0,
            chess.KNIGHT: 1,
            chess.BISHOP: 2,
            chess.ROOK: 3,
            chess.QUEEN: 4,
            chess.KING: 5,
        }
        return offset + piece_type_map[piece.piece_type]

    def predict(self, board):
        """Run the network on a single board; returns (policy_logits, value)."""
        x = self.board_to_tensor(board)
        with torch.no_grad():
            policy_logits, value = self.model(x)
        return policy_logits, value

    def diffusion_sample(self, policy_logits, steps=10, noise_scale=1.0, schedule_type='linear'):
        """
        Backward (denoising) diffusion over the policy logits.
        - Start from pure noise and iteratively blend toward policy_logits.
        - Supports linear and cosine schedules for the noise weight alpha,
          which decays toward 0 (all signal) over the course of the steps.
        - Injects a small, decaying amount of fresh noise at each step.
        """
        orig = policy_logits.clone()
        x = torch.randn_like(orig) * noise_scale
        if schedule_type == 'cosine':
            # Cosine schedule: alpha = cos((i/steps) * pi/2) decays from ~1 to 0
            alphas = [np.cos((i / steps) * np.pi / 2) for i in range(1, steps + 1)]
        else:
            # Linear schedule: alpha decays from (steps-1)/steps to 0
            alphas = np.linspace(1.0, 0.0, steps + 1)[1:]
        for alpha in alphas:
            # Denoising step: weighted average between the current sample and the target logits
            x = alpha * x + (1 - alpha) * orig
            # Add decreasing noise for stochasticity
            x = x + torch.randn_like(x) * (noise_scale * (alpha ** 2) / 2)
        return x
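
    # Illustrative schedule values for diffusion_sample (hypothetical, steps=4):
    #   linear: alpha = 0.75, 0.50, 0.25, 0.00
    #   cosine: alpha = cos(pi/8), cos(pi/4), cos(3*pi/8), cos(pi/2)
    #                 ~ 0.92,      0.71,      0.38,        0.00
    # The cosine schedule keeps alpha (the weight on the noisy sample) higher
    # for longer, so more randomness survives into the middle steps.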

    def predict_with_diffusion(self, board, steps=10, noise_scale=1.0, schedule_type='linear'):
        """Like predict(), but samples the policy logits via diffusion_sample()."""
        x = self.board_to_tensor(board)
        with torch.no_grad():
            policy_logits, value = self.model(x)
        diffused_logits = self.diffusion_sample(policy_logits, steps=steps, noise_scale=noise_scale, schedule_type=schedule_type)
        return diffused_logits, value

    def encode_move(self, move):
        """Encode a chess.Move as an integer index in [0, 20480)."""
        # index = from_square * (64 * 5) + to_square * 5 + promotion code,
        # where the promotion code is 0 (none), 1 (N), 2 (B), 3 (R), or 4 (Q)
        from_sq = move.from_square
        to_sq = move.to_square
        promo = 0
        if move.promotion:
            promo = {chess.KNIGHT: 1, chess.BISHOP: 2, chess.ROOK: 3, chess.QUEEN: 4}[move.promotion]
        return from_sq * 64 * 5 + to_sq * 5 + promo

    def decode_move(self, idx, board):
        """Decode an integer index to a legal chess.Move for the given board."""
        from_sq = idx // (64 * 5)
        to_sq = (idx // 5) % 64
        promo = idx % 5
        promotion = None
        if promo:
            promotion = [None, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN][promo]
        move = chess.Move(from_sq, to_sq, promotion=promotion)
        if move in board.legal_moves:
            return move
        # If the decoded move is not legal, fall back to a random legal move
        return random.choice(list(board.legal_moves))
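

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original upload): run the untrained
    # network on the starting position and pick a move from the diffused logits.
    board = chess.Board()
    agent = Agent(device='cpu')

    logits, value = agent.predict_with_diffusion(board, steps=10, schedule_type='cosine')
    print("value estimate:", value.item())

    # Score only the encoded legal moves, so no illegal-move fallback is needed.
    legal = list(board.legal_moves)
    scores = [logits[0, agent.encode_move(m)].item() for m in legal]
    best = legal[int(np.argmax(scores))]
    print("chosen move:", best.uci())

    # encode_move / decode_move round-trip on a legal move
    assert agent.decode_move(agent.encode_move(best), board) == best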