FlameF0X committed on
Commit
cc24eeb
·
verified ·
1 Parent(s): 4867fd5

Upload 3 files

Browse files
Files changed (3) hide show
  1. src/mcts.py +88 -0
  2. src/o2_agent.py +78 -0
  3. src/o2_model.py +77 -0
src/mcts.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import chess
2
+ import torch
3
+ import numpy as np
4
+ from o2_model import board_to_tensor
5
+
6
class MCTSNode:
    """One position in the search tree plus the statistics gathered for it."""

    def __init__(self, board, parent=None, move=None):
        """Create a node holding its own copy of ``board``.

        Args:
            board: Position for this node. It must provide ``copy()``; the
                node keeps the copy, so later changes to ``self.board`` do not
                touch the caller's object.
            parent: Node this one was reached from, or ``None`` for a root.
            move: Move that leads from ``parent`` to this node, or ``None``.
        """
        self.board = board.copy()
        self.parent = parent
        self.move = move
        self.children = {}  # move -> MCTSNode
        self.N = 0  # Visit count
        self.W = 0  # Total value
        self.Q = 0  # Mean value
        self.P = 0  # Prior probability

    def __repr__(self):
        """Return a short summary of the node's move and statistics."""
        return (
            f"{type(self).__name__}(move={self.move!r}, N={self.N}, "
            f"W={self.W}, Q={self.Q}, P={self.P})"
        )
16
+
17
+ class MCTS:
18
+ def __init__(self, model, simulations=100, c_puct=1.5):
19
+ self.model = model
20
+ self.simulations = simulations
21
+ self.c_puct = c_puct
22
+
23
+ def run(self, board):
24
+ root = MCTSNode(board)
25
+ self._expand(root)
26
+ for _ in range(self.simulations):
27
+ node = root
28
+ search_path = [node]
29
+ # Selection
30
+ while node.children:
31
+ max_ucb = -float('inf')
32
+ best_move = None
33
+ for move, child in node.children.items():
34
+ ucb = child.Q + self.c_puct * child.P * np.sqrt(node.N) / (1 + child.N)
35
+ if ucb > max_ucb:
36
+ max_ucb = ucb
37
+ best_move = move
38
+ node = node.children[best_move]
39
+ search_path.append(node)
40
+ # Expansion
41
+ value = self._expand(node)
42
+ # Backpropagation
43
+ for n in reversed(search_path):
44
+ n.N += 1
45
+ n.W += value
46
+ n.Q = n.W / n.N
47
+ value = -value # Switch perspective
48
+ # Choose move with highest visit count
49
+ best_move = max(root.children.items(), key=lambda item: item[1].N)[0]
50
+ return best_move
51
+
52
+ def _expand(self, node):
53
+ if node.board.is_game_over():
54
+ result = node.board.result()
55
+ if result == '1-0':
56
+ return 1
57
+ elif result == '0-1':
58
+ return -1
59
+ else:
60
+ return 0
61
+ tensor = torch.tensor(board_to_tensor(node.board)).unsqueeze(0)
62
+ with torch.no_grad():
63
+ policy, value = self.model(tensor)
64
+ policy = torch.softmax(policy, dim=1).numpy()[0]
65
+ legal_moves = list(node.board.legal_moves)
66
+ total_p = 1e-8
67
+ for move in legal_moves:
68
+ idx = self.move_to_index(move)
69
+ p = policy[idx]
70
+ total_p += p
71
+ for move in legal_moves:
72
+ idx = self.move_to_index(move)
73
+ p = policy[idx] / total_p
74
+ child_board = node.board.copy()
75
+ child_board.push(move)
76
+ child = MCTSNode(child_board, parent=node, move=move)
77
+ child.P = p
78
+ node.children[move] = child
79
+ return value.item()
80
+
81
+ def move_to_index(self, move):
82
+ from_square = move.from_square
83
+ to_square = move.to_square
84
+ promotion = move.promotion if move.promotion else 0
85
+ promotion_offset = 0
86
+ if promotion:
87
+ promotion_offset = 4096 + (promotion - 1)
88
+ return from_square * 64 + to_square + promotion_offset
src/o2_agent.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import chess
2
+ import torch
3
+ from o2_model import O2Net, board_to_tensor
4
+ from mcts import MCTS
5
+ import random
6
+
7
+ # Optional: Endgame tablebase and opening book integration placeholders
8
+ # You can use python-chess's tablebase and opening book modules if desired
9
+ # Example for endgame tablebase:
10
+ # import chess.syzygy
11
+ # tb = chess.syzygy.open_tablebase('/path/to/syzygy')
12
+ # wdl = tb.get_wdl(board)  # None when the position is not in the tablebase
13
+ # if wdl is not None:
14
+ # # Use tablebase move
15
+ # Example for opening book:
16
+ # from chess.polyglot import open_reader
17
+ # with open_reader('book.bin') as reader:
18
+ # entry = reader.find(board)
19
+ # move = entry.move
20
+
21
+ class O2Agent:
22
+ def __init__(self, model_path=None):
23
+ self.model = O2Net()
24
+ if model_path:
25
+ self.model.load_state_dict(torch.load(model_path))
26
+ self.model.eval()
27
+
28
+ def select_move(self, board, use_mcts=True, simulations=100):
29
+ if use_mcts:
30
+ mcts = MCTS(self.model, simulations=simulations)
31
+ return mcts.run(board)
32
+ tensor = torch.tensor(board_to_tensor(board)).unsqueeze(0)
33
+ with torch.no_grad():
34
+ policy, _ = self.model(tensor)
35
+ legal_moves = list(board.legal_moves)
36
+ move_scores = []
37
+ for move in legal_moves:
38
+ move_idx = self.move_to_index(move)
39
+ move_scores.append(policy[0, move_idx].item())
40
+ best_move = legal_moves[int(torch.tensor(move_scores).argmax())]
41
+ return best_move
42
+
43
+ def move_to_index(self, move):
44
+ # Encode move as from_square * 64 + to_square + promotion_offset
45
+ from_square = move.from_square
46
+ to_square = move.to_square
47
+ promotion = move.promotion if move.promotion else 0
48
+ promotion_offset = 0
49
+ if promotion:
50
+ # Promotion: 1=Knight, 2=Bishop, 3=Rook, 4=Queen (python-chess)
51
+ # Offset: 4096 + (promotion-1)*64*64//4
52
+ promotion_offset = 4096 + (promotion - 1) * 256
53
+ idx = from_square * 64 + to_square + promotion_offset
54
+ # Ensure index is within bounds
55
+ return idx if idx < 4672 else idx % 4672
56
+
57
+ def index_to_move(self, board, index):
58
+ # Decode index to move (reverse of move_to_index)
59
+ if index >= 4096:
60
+ promotion = (index - 4096) % 4 + 1
61
+ idx = index - 4096
62
+ from_square = idx // 64
63
+ to_square = idx % 64
64
+ move = chess.Move(from_square, to_square, promotion=promotion)
65
+ else:
66
+ from_square = index // 64
67
+ to_square = index % 64
68
+ move = chess.Move(from_square, to_square)
69
+ if move in board.legal_moves:
70
+ return move
71
+ # Fallback: pick a random legal move
72
+ return random.choice(list(board.legal_moves))
73
+
74
# Demo: let a freshly initialised (untrained) agent pick a move from the
# starting position using the default MCTS settings.
if __name__ == "__main__":
    board = chess.Board()
    agent = O2Agent()
    move = agent.select_move(board)
    print("O2 selects:", move)
src/o2_model.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import chess
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
class O2Net(nn.Module):
    """Fully connected residual network with a policy head and a value head."""

    def __init__(self, input_dim=1152, hidden_dim=1024, num_blocks=10,
                 policy_dim=4672):
        """Build the layers.

        The defaults reproduce the original fixed architecture, and the
        attribute names are unchanged, so existing state dicts still load.

        Args:
            input_dim: Length of the flattened board encoding (18 * 8 * 8).
            hidden_dim: Width of the trunk.
            num_blocks: Number of residual blocks in the trunk.
            policy_dim: Number of policy logits.
        """
        super().__init__()
        # Input layer
        self.input_fc = nn.Linear(input_dim, hidden_dim)
        # Residual trunk: Linear -> BN -> ReLU -> Linear -> BN per block
        self.res_blocks = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
            )
            for _ in range(num_blocks)
        ])
        self.res_relu = nn.ReLU()
        # Policy head
        self.policy_fc1 = nn.Linear(hidden_dim, 512)
        self.policy_fc2 = nn.Linear(512, 256)
        self.policy_fc3 = nn.Linear(256, policy_dim)
        # Value head
        self.value_fc1 = nn.Linear(hidden_dim, 512)
        self.value_fc2 = nn.Linear(512, 128)
        self.value_fc3 = nn.Linear(128, 1)

    def forward(self, x):
        """Return ``(policy_logits, value)`` for a batch of encodings.

        Args:
            x: Tensor of shape ``(batch, input_dim)``. Inputs with more than
                two dimensions are flattened per sample first.

        Returns:
            ``policy_logits`` of shape ``(batch, policy_dim)`` (unnormalised)
            and ``value`` of shape ``(batch, 1)`` squashed by ``tanh``.
        """
        if x.dim() > 2:
            x = x.flatten(start_dim=1)
        x = F.relu(self.input_fc(x))
        for block in self.res_blocks:
            # Skip connection around each block, activation after the sum.
            x = self.res_relu(block(x) + x)
        # Policy head
        p = F.relu(self.policy_fc1(x))
        p = F.relu(self.policy_fc2(p))
        policy = self.policy_fc3(p)
        # Value head
        v = F.relu(self.value_fc1(x))
        v = F.relu(self.value_fc2(v))
        value = torch.tanh(self.value_fc3(v))
        return policy, value
47
+
48
def board_to_tensor(board):
    """Encode ``board`` as a flat float32 vector of 18 * 8 * 8 = 1152 values.

    Plane layout (each plane is 8x8, indexed ``[rank, file]``):
        0-5   white pieces, one plane per piece type (``piece_type - 1``)
        6-11  black pieces, same order
        12    side to move (all ones when it is White's turn)
        13-16 castling rights: White kingside, White queenside,
              Black kingside, Black queenside (whole plane set)
        17    en passant square, if any

    Returns:
        A one-dimensional ``numpy.float32`` array of length 1152.
    """
    planes = np.zeros((18, 8, 8), dtype=np.float32)

    # Piece planes
    for square, piece in board.piece_map().items():
        color_offset = 0 if piece.color == chess.WHITE else 6
        rank, file = divmod(square, 8)
        planes[piece.piece_type - 1 + color_offset, rank, file] = 1

    # Turn plane
    planes[12] = int(board.turn)

    # Castling rights
    castling_rights = (
        board.has_kingside_castling_rights(chess.WHITE),
        board.has_queenside_castling_rights(chess.WHITE),
        board.has_kingside_castling_rights(chess.BLACK),
        board.has_queenside_castling_rights(chess.BLACK),
    )
    for offset, has_right in enumerate(castling_rights):
        planes[13 + offset] = int(has_right)

    # En passant
    if board.ep_square is not None:
        rank, file = divmod(board.ep_square, 8)
        planes[17, rank, file] = 1

    return planes.flatten()
70
+
71
# Smoke test: push the starting position through an untrained network.
if __name__ == "__main__":
    board = chess.Board()
    net = O2Net()
    # BatchNorm1d cannot compute batch statistics from a single sample in
    # training mode, so run the one-position demo in eval mode.
    net.eval()
    x = torch.tensor(board_to_tensor(board)).unsqueeze(0)
    with torch.no_grad():
        policy, value = net(x)
    print("Policy shape:", policy.shape)
    print("Value:", value.item())