import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import chess
import random
import math

class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.fc1 = nn.Linear(channels, channels // reduction)
        self.fc2 = nn.Linear(channels // reduction, channels)

    def forward(self, x):
        b, c, h, w = x.size()
        y = x.view(b, c, -1).mean(dim=2)
        y = F.relu(self.fc1(y))
        y = torch.sigmoid(self.fc2(y))
        y = y.view(b, c, 1, 1)
        return x * y

class ResidualBlock(nn.Module):
    def __init__(self, channels, dropout=0.2):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.se = SEBlock(channels)
        self.dropout = nn.Dropout2d(dropout)

    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.se(out)
        out = self.dropout(out)
        out += residual
        return F.relu(out)

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=64):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        x = x + self.pe[:, :x.size(1), :]
        return x

class ChessNet(nn.Module):
    def __init__(self, input_channels=20, board_size=8, policy_size=20480, num_blocks=20, transformer_layers=2, nhead=8):
        # policy_size = 64*64*5, matching Agent.encode_move's
        # from_square*64*5 + to_square*5 + promotion indexing
        super().__init__()
        self.board_size = board_size
        self.conv_in = nn.Conv2d(input_channels, 256, kernel_size=3, padding=1)
        self.bn_in = nn.BatchNorm2d(256)
        self.res_blocks = nn.Sequential(*[ResidualBlock(256) for _ in range(num_blocks)])
        # Transformer encoder
        self.pos_encoder = PositionalEncoding(256, max_len=board_size*board_size)
        encoder_layer = nn.TransformerEncoderLayer(d_model=256, nhead=nhead, dim_feedforward=512, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers)
        self.fc1 = nn.Linear(256 * board_size * board_size, 512)
        self.ln_fc1 = nn.LayerNorm(512)
        # Policy head
        self.policy_head1 = nn.Linear(512, 256)
        self.policy_head2 = nn.Linear(256, policy_size)
        # Value head
        self.value_head1 = nn.Linear(512, 128)
        self.value_head2 = nn.Linear(128, 1)

    def forward(self, x):
        x = F.relu(self.bn_in(self.conv_in(x)))  # (B, 256, 8, 8)
        x = self.res_blocks(x)  # (B, 256, 8, 8)
        B, C, H, W = x.shape
        x = x.view(B, C, H * W).permute(0, 2, 1)  # (B, 64, 256)
        x = self.pos_encoder(x)  # (B, 64, 256)
        x = self.transformer(x)  # (B, 64, 256)
        x = x.permute(0, 2, 1).contiguous().view(B, -1)  # (B, 256*64)
        x = F.relu(self.ln_fc1(self.fc1(x)))
        # Policy head
        policy = F.relu(self.policy_head1(x))
        policy = self.policy_head2(policy)
        # Value head
        value = F.relu(self.value_head1(x))
        value = torch.tanh(self.value_head2(value))
        return policy, value
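
# A minimal shape sanity check for ChessNet (an illustrative sketch, not part
# of the original code): verifies that a batch of 20-channel board encodings,
# as produced by Agent.board_to_tensor below, flows through the conv,
# residual, and transformer stages with the expected output shapes.
if __name__ == "__main__":
    _net = ChessNet().eval()
    with torch.no_grad():
        _policy, _value = _net(torch.zeros(2, 20, 8, 8))
    assert _policy.shape == (2, 20480) and _value.shape == (2, 1)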

class Agent:
    def __init__(self, device='cpu'):
        self.device = device
        self.model = ChessNet().to(device)
        self.model.eval()

    def board_to_tensor(self, board):
        # 12 binary piece planes (6 piece types x 2 colors), then auxiliary
        # planes: castling rights (4), move count (1), en passant (1),
        # repetition (1), and halfmove clock (1) -- 20 channels in total
        piece_map = board.piece_map()
        tensor = np.zeros((17, 8, 8), dtype=np.float32)
        for square, piece in piece_map.items():
            idx = self.piece_to_index(piece)
            row, col = divmod(square, 8)
            tensor[idx, row, col] = 1
        # Add castling rights (4 planes)
        if board.has_kingside_castling_rights(chess.WHITE):
            tensor[12, :, :] = 1
        if board.has_queenside_castling_rights(chess.WHITE):
            tensor[13, :, :] = 1
        if board.has_kingside_castling_rights(chess.BLACK):
            tensor[14, :, :] = 1
        if board.has_queenside_castling_rights(chess.BLACK):
            tensor[15, :, :] = 1
        # Add move count (normalized, 1 plane)
        tensor[16, :, :] = board.fullmove_number / 100.0
        # Add en passant square (1 plane)
        tensor = np.concatenate([tensor, np.zeros((1, 8, 8), dtype=np.float32)], axis=0)
        if board.ep_square is not None:
            row, col = divmod(board.ep_square, 8)
            tensor[-1, row, col] = 1
        # Add repetition count (1 plane, normalized): the two boolean checks
        # sum to 0, 1 (two-fold), or 2 (three-fold repetition)
        rep_count = board.is_repetition(3) + board.is_repetition(2)
        tensor = np.concatenate([tensor, np.full((1, 8, 8), rep_count / 3.0, dtype=np.float32)], axis=0)
        # Add 50-move rule counter (1 plane, normalized)
        tensor = np.concatenate([tensor, np.full((1, 8, 8), board.halfmove_clock / 100.0, dtype=np.float32)], axis=0)
        return torch.tensor(tensor, device=self.device).unsqueeze(0)

    def piece_to_index(self, piece):
        # 0-5: white P,N,B,R,Q,K; 6-11: black P,N,B,R,Q,K
        offset = 0 if piece.color == chess.WHITE else 6
        piece_type_map = {
            chess.PAWN: 0,
            chess.KNIGHT: 1,
            chess.BISHOP: 2,
            chess.ROOK: 3,
            chess.QUEEN: 4,
            chess.KING: 5
        }
        return offset + piece_type_map[piece.piece_type]

    def predict(self, board):
        x = self.board_to_tensor(board)
        with torch.no_grad():
            policy_logits, value = self.model(x)
        return policy_logits, value

    def diffusion_sample(self, policy_logits, steps=10, noise_scale=1.0, schedule_type='linear'):
        """Backward (denoising) diffusion process over the policy logits.

        - Start from noise and iteratively denoise toward policy_logits.
        - Supports linear and cosine schedules for noise reduction.
        - Adds decreasing stochastic noise at each step.
        """
        orig = policy_logits.clone()
        x = torch.randn_like(orig) * noise_scale
        if schedule_type == 'cosine':
            # Cosine schedule: alpha decays from ~1 toward 0, so the final
            # steps keep almost none of the noise component
            alphas = [np.cos((i / steps) * np.pi / 2) for i in range(1, steps + 1)]
        else:
            # Linear schedule, likewise decreasing toward 0
            alphas = np.linspace(1.0, 0.0, steps + 1)[1:]
        for alpha in alphas:
            # Denoising step: interpolate between the noisy state and the logits
            x = alpha * x + (1 - alpha) * orig
            # Add stochastic noise that shrinks as alpha decays
            x = x + torch.randn_like(x) * (noise_scale * (alpha ** 2) / 2)
        return x

    def predict_with_diffusion(self, board, steps=10, noise_scale=1.0, schedule_type='linear'):
        x = self.board_to_tensor(board)
        with torch.no_grad():
            policy_logits, value = self.model(x)
            diffused_logits = self.diffusion_sample(policy_logits, steps=steps, noise_scale=noise_scale, schedule_type=schedule_type)
        return diffused_logits, value

    def encode_move(self, move):
        """Encode a chess.Move as an integer index in [0, 64*64*5).

        Encoding: from_square * 64 * 5 + to_square * 5 + promotion, where
        promotion is 0 (none) or 1-4 (knight, bishop, rook, queen).
        E.g. e7e8=Q (from 52, to 60): 52*320 + 60*5 + 4 = 16944.
        """
        from_sq = move.from_square
        to_sq = move.to_square
        promo = 0
        if move.promotion:
            # 1: knight, 2: bishop, 3: rook, 4: queen
            promo = {chess.KNIGHT: 1, chess.BISHOP: 2, chess.ROOK: 3, chess.QUEEN: 4}[move.promotion]
        return from_sq * 64 * 5 + to_sq * 5 + promo

    def decode_move(self, idx, board):
        """Decode an integer index to a legal chess.Move for the given board."""
        from_sq = idx // (64 * 5)
        to_sq = (idx // 5) % 64
        promo = idx % 5
        promotion = None
        if promo:
            promotion = [None, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN][promo]
        move = chess.Move(from_sq, to_sq, promotion=promotion)
        if move in board.legal_moves:
            return move
        # If not legal, return a random legal move as fallback
        return random.choice(list(board.legal_moves))
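
# Usage sketch (an illustrative example, not part of the original API):
# encode the starting position, sample diffused policy logits, restrict them
# to the legal moves via encode_move, and play the highest-scoring move.
if __name__ == "__main__":
    agent = Agent(device='cpu')
    board = chess.Board()
    logits, value = agent.predict_with_diffusion(board, steps=10, schedule_type='cosine')
    # Mask the policy to legal moves before choosing one
    legal = list(board.legal_moves)
    scores = [logits[0, agent.encode_move(m)].item() for m in legal]
    best = legal[int(np.argmax(scores))]
    print(f"value={value.item():+.3f} best_move={best.uci()}")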