Spaces:
Sleeping
Sleeping
Upload 3 files
Browse files- src/mcts.py +88 -0
- src/o2_agent.py +78 -0
- src/o2_model.py +77 -0
src/mcts.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import chess
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from o2_model import board_to_tensor
|
5 |
+
|
6 |
+
class MCTSNode:
|
7 |
+
def __init__(self, board, parent=None, move=None):
|
8 |
+
self.board = board.copy()
|
9 |
+
self.parent = parent
|
10 |
+
self.move = move
|
11 |
+
self.children = {}
|
12 |
+
self.N = 0 # Visit count
|
13 |
+
self.W = 0 # Total value
|
14 |
+
self.Q = 0 # Mean value
|
15 |
+
self.P = 0 # Prior probability
|
16 |
+
|
17 |
+
class MCTS:
|
18 |
+
def __init__(self, model, simulations=100, c_puct=1.5):
|
19 |
+
self.model = model
|
20 |
+
self.simulations = simulations
|
21 |
+
self.c_puct = c_puct
|
22 |
+
|
23 |
+
def run(self, board):
|
24 |
+
root = MCTSNode(board)
|
25 |
+
self._expand(root)
|
26 |
+
for _ in range(self.simulations):
|
27 |
+
node = root
|
28 |
+
search_path = [node]
|
29 |
+
# Selection
|
30 |
+
while node.children:
|
31 |
+
max_ucb = -float('inf')
|
32 |
+
best_move = None
|
33 |
+
for move, child in node.children.items():
|
34 |
+
ucb = child.Q + self.c_puct * child.P * np.sqrt(node.N) / (1 + child.N)
|
35 |
+
if ucb > max_ucb:
|
36 |
+
max_ucb = ucb
|
37 |
+
best_move = move
|
38 |
+
node = node.children[best_move]
|
39 |
+
search_path.append(node)
|
40 |
+
# Expansion
|
41 |
+
value = self._expand(node)
|
42 |
+
# Backpropagation
|
43 |
+
for n in reversed(search_path):
|
44 |
+
n.N += 1
|
45 |
+
n.W += value
|
46 |
+
n.Q = n.W / n.N
|
47 |
+
value = -value # Switch perspective
|
48 |
+
# Choose move with highest visit count
|
49 |
+
best_move = max(root.children.items(), key=lambda item: item[1].N)[0]
|
50 |
+
return best_move
|
51 |
+
|
52 |
+
def _expand(self, node):
|
53 |
+
if node.board.is_game_over():
|
54 |
+
result = node.board.result()
|
55 |
+
if result == '1-0':
|
56 |
+
return 1
|
57 |
+
elif result == '0-1':
|
58 |
+
return -1
|
59 |
+
else:
|
60 |
+
return 0
|
61 |
+
tensor = torch.tensor(board_to_tensor(node.board)).unsqueeze(0)
|
62 |
+
with torch.no_grad():
|
63 |
+
policy, value = self.model(tensor)
|
64 |
+
policy = torch.softmax(policy, dim=1).numpy()[0]
|
65 |
+
legal_moves = list(node.board.legal_moves)
|
66 |
+
total_p = 1e-8
|
67 |
+
for move in legal_moves:
|
68 |
+
idx = self.move_to_index(move)
|
69 |
+
p = policy[idx]
|
70 |
+
total_p += p
|
71 |
+
for move in legal_moves:
|
72 |
+
idx = self.move_to_index(move)
|
73 |
+
p = policy[idx] / total_p
|
74 |
+
child_board = node.board.copy()
|
75 |
+
child_board.push(move)
|
76 |
+
child = MCTSNode(child_board, parent=node, move=move)
|
77 |
+
child.P = p
|
78 |
+
node.children[move] = child
|
79 |
+
return value.item()
|
80 |
+
|
81 |
+
def move_to_index(self, move):
|
82 |
+
from_square = move.from_square
|
83 |
+
to_square = move.to_square
|
84 |
+
promotion = move.promotion if move.promotion else 0
|
85 |
+
promotion_offset = 0
|
86 |
+
if promotion:
|
87 |
+
promotion_offset = 4096 + (promotion - 1)
|
88 |
+
return from_square * 64 + to_square + promotion_offset
|
src/o2_agent.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import chess
|
2 |
+
import torch
|
3 |
+
from o2_model import O2Net, board_to_tensor
|
4 |
+
from mcts import MCTS
|
5 |
+
import random
|
6 |
+
|
7 |
+
# Optional: Endgame tablebase and opening book integration placeholders
|
8 |
+
# You can use python-chess's tablebase and opening book modules if desired
|
9 |
+
# Example for endgame tablebase:
|
10 |
+
# from chess import tablebase
|
11 |
+
# tb = tablebase.Tablebase()
|
12 |
+
# tb.add_tablebase('/path/to/syzygy')
|
13 |
+
# if tb.probe_wdl(board) is not None:
|
14 |
+
# # Use tablebase move
|
15 |
+
# Example for opening book:
|
16 |
+
# from chess.polyglot import open_reader
|
17 |
+
# with open_reader('book.bin') as reader:
|
18 |
+
# entry = reader.find(board)
|
19 |
+
# move = entry.move
|
20 |
+
|
21 |
+
class O2Agent:
|
22 |
+
def __init__(self, model_path=None):
|
23 |
+
self.model = O2Net()
|
24 |
+
if model_path:
|
25 |
+
self.model.load_state_dict(torch.load(model_path))
|
26 |
+
self.model.eval()
|
27 |
+
|
28 |
+
def select_move(self, board, use_mcts=True, simulations=100):
|
29 |
+
if use_mcts:
|
30 |
+
mcts = MCTS(self.model, simulations=simulations)
|
31 |
+
return mcts.run(board)
|
32 |
+
tensor = torch.tensor(board_to_tensor(board)).unsqueeze(0)
|
33 |
+
with torch.no_grad():
|
34 |
+
policy, _ = self.model(tensor)
|
35 |
+
legal_moves = list(board.legal_moves)
|
36 |
+
move_scores = []
|
37 |
+
for move in legal_moves:
|
38 |
+
move_idx = self.move_to_index(move)
|
39 |
+
move_scores.append(policy[0, move_idx].item())
|
40 |
+
best_move = legal_moves[int(torch.tensor(move_scores).argmax())]
|
41 |
+
return best_move
|
42 |
+
|
43 |
+
def move_to_index(self, move):
|
44 |
+
# Encode move as from_square * 64 + to_square + promotion_offset
|
45 |
+
from_square = move.from_square
|
46 |
+
to_square = move.to_square
|
47 |
+
promotion = move.promotion if move.promotion else 0
|
48 |
+
promotion_offset = 0
|
49 |
+
if promotion:
|
50 |
+
# Promotion: 1=Knight, 2=Bishop, 3=Rook, 4=Queen (python-chess)
|
51 |
+
# Offset: 4096 + (promotion-1)*64*64//4
|
52 |
+
promotion_offset = 4096 + (promotion - 1) * 256
|
53 |
+
idx = from_square * 64 + to_square + promotion_offset
|
54 |
+
# Ensure index is within bounds
|
55 |
+
return idx if idx < 4672 else idx % 4672
|
56 |
+
|
57 |
+
def index_to_move(self, board, index):
|
58 |
+
# Decode index to move (reverse of move_to_index)
|
59 |
+
if index >= 4096:
|
60 |
+
promotion = (index - 4096) % 4 + 1
|
61 |
+
idx = index - 4096
|
62 |
+
from_square = idx // 64
|
63 |
+
to_square = idx % 64
|
64 |
+
move = chess.Move(from_square, to_square, promotion=promotion)
|
65 |
+
else:
|
66 |
+
from_square = index // 64
|
67 |
+
to_square = index % 64
|
68 |
+
move = chess.Move(from_square, to_square)
|
69 |
+
if move in board.legal_moves:
|
70 |
+
return move
|
71 |
+
# Fallback: pick a random legal move
|
72 |
+
return random.choice(list(board.legal_moves))
|
73 |
+
|
74 |
+
if __name__ == "__main__":
|
75 |
+
board = chess.Board()
|
76 |
+
agent = O2Agent()
|
77 |
+
move = agent.select_move(board)
|
78 |
+
print("O2 selects:", move)
|
src/o2_model.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import chess
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
class O2Net(nn.Module):
|
8 |
+
def __init__(self):
|
9 |
+
super(O2Net, self).__init__()
|
10 |
+
# Input layer (updated to 1152 for 8x8x18 encoding)
|
11 |
+
self.input_fc = nn.Linear(1152, 1024)
|
12 |
+
# 10 deep residual blocks
|
13 |
+
self.res_blocks = nn.ModuleList([
|
14 |
+
nn.Sequential(
|
15 |
+
nn.Linear(1024, 1024),
|
16 |
+
nn.BatchNorm1d(1024),
|
17 |
+
nn.ReLU(),
|
18 |
+
nn.Linear(1024, 1024),
|
19 |
+
nn.BatchNorm1d(1024)
|
20 |
+
) for _ in range(10)
|
21 |
+
])
|
22 |
+
self.res_relu = nn.ReLU()
|
23 |
+
# Policy head
|
24 |
+
self.policy_fc1 = nn.Linear(1024, 512)
|
25 |
+
self.policy_fc2 = nn.Linear(512, 256)
|
26 |
+
self.policy_fc3 = nn.Linear(256, 4672)
|
27 |
+
# Value head
|
28 |
+
self.value_fc1 = nn.Linear(1024, 512)
|
29 |
+
self.value_fc2 = nn.Linear(512, 128)
|
30 |
+
self.value_fc3 = nn.Linear(128, 1)
|
31 |
+
|
32 |
+
def forward(self, x):
|
33 |
+
x = F.relu(self.input_fc(x))
|
34 |
+
for block in self.res_blocks:
|
35 |
+
residual = x
|
36 |
+
out = block(x)
|
37 |
+
x = self.res_relu(out + residual)
|
38 |
+
# Policy head
|
39 |
+
p = F.relu(self.policy_fc1(x))
|
40 |
+
p = F.relu(self.policy_fc2(p))
|
41 |
+
policy = self.policy_fc3(p)
|
42 |
+
# Value head
|
43 |
+
v = F.relu(self.value_fc1(x))
|
44 |
+
v = F.relu(self.value_fc2(v))
|
45 |
+
value = torch.tanh(self.value_fc3(v))
|
46 |
+
return policy, value
|
47 |
+
|
48 |
+
def board_to_tensor(board):
|
49 |
+
# Improved encoding: 8x8x18 planes (12 for pieces, 6 for state), flattened
|
50 |
+
# 12 planes: one for each piece type/color
|
51 |
+
# 6 planes: turn, castling rights (4), en passant
|
52 |
+
planes = np.zeros((18, 8, 8), dtype=np.float32)
|
53 |
+
piece_map = board.piece_map()
|
54 |
+
for square, piece in piece_map.items():
|
55 |
+
plane = (piece.piece_type - 1) + (0 if piece.color == chess.WHITE else 6)
|
56 |
+
row, col = divmod(square, 8)
|
57 |
+
planes[plane, row, col] = 1
|
58 |
+
# Turn plane
|
59 |
+
planes[12, :, :] = int(board.turn)
|
60 |
+
# Castling rights
|
61 |
+
planes[13, :, :] = int(board.has_kingside_castling_rights(chess.WHITE))
|
62 |
+
planes[14, :, :] = int(board.has_queenside_castling_rights(chess.WHITE))
|
63 |
+
planes[15, :, :] = int(board.has_kingside_castling_rights(chess.BLACK))
|
64 |
+
planes[16, :, :] = int(board.has_queenside_castling_rights(chess.BLACK))
|
65 |
+
# En passant
|
66 |
+
if board.ep_square is not None:
|
67 |
+
row, col = divmod(board.ep_square, 8)
|
68 |
+
planes[17, row, col] = 1
|
69 |
+
return planes.flatten()
|
70 |
+
|
71 |
+
if __name__ == "__main__":
|
72 |
+
board = chess.Board()
|
73 |
+
net = O2Net()
|
74 |
+
x = torch.tensor(board_to_tensor(board)).unsqueeze(0)
|
75 |
+
policy, value = net(x)
|
76 |
+
print("Policy shape:", policy.shape)
|
77 |
+
print("Value:", value.item())
|