Upload 11 files
- src/o1/__init__.py +1 -0
- src/o1/__pycache__/__init__.cpython-312.pyc +0 -0
- src/o1/__pycache__/agent.cpython-312.pyc +0 -0
- src/o1/__pycache__/mcts.cpython-312.pyc +0 -0
- src/o1/__pycache__/train.cpython-312.pyc +0 -0
- src/o1/__pycache__/utils.cpython-312.pyc +0 -0
- src/o1/agent.py +133 -0
- src/o1/mcts.py +96 -0
- src/o1/selfplay.py +37 -0
- src/o1/train.py +162 -0
- src/o1/utils.py +144 -0
src/o1/__init__.py
ADDED
# o1 package
src/o1/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (146 Bytes)
src/o1/__pycache__/agent.cpython-312.pyc
ADDED
Binary file (9.63 kB)
src/o1/__pycache__/mcts.cpython-312.pyc
ADDED
Binary file (5.4 kB)
src/o1/__pycache__/train.cpython-312.pyc
ADDED
Binary file (8.77 kB)
src/o1/__pycache__/utils.cpython-312.pyc
ADDED
Binary file (7 kB)
src/o1/agent.py
ADDED
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import chess

class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.fc1 = nn.Linear(channels, channels // reduction)
        self.fc2 = nn.Linear(channels // reduction, channels)

    def forward(self, x):
        b, c, h, w = x.size()
        y = x.view(b, c, -1).mean(dim=2)
        y = F.relu(self.fc1(y))
        y = torch.sigmoid(self.fc2(y))
        y = y.view(b, c, 1, 1)
        return x * y

class ResidualBlock(nn.Module):
    def __init__(self, channels, dropout=0.2):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.se = SEBlock(channels)
        self.dropout = nn.Dropout2d(dropout)

    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.se(out)
        out = self.dropout(out)
        out += residual
        return F.relu(out)

class ChessNet(nn.Module):
    def __init__(self, input_channels=17, board_size=8, policy_size=4672, num_blocks=20):
        super().__init__()
        self.conv_in = nn.Conv2d(input_channels, 256, kernel_size=3, padding=1)
        self.bn_in = nn.BatchNorm2d(256)
        self.res_blocks = nn.Sequential(*[ResidualBlock(256) for _ in range(num_blocks)])
        self.fc1 = nn.Linear(256 * board_size * board_size, 512)
        self.ln_fc1 = nn.LayerNorm(512)
        # Policy head
        self.policy_head1 = nn.Linear(512, 256)
        self.policy_head2 = nn.Linear(256, policy_size)
        # Value head
        self.value_head1 = nn.Linear(512, 128)
        self.value_head2 = nn.Linear(128, 1)

    def forward(self, x):
        x = F.relu(self.bn_in(self.conv_in(x)))
        x = self.res_blocks(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.ln_fc1(self.fc1(x)))
        # Policy head
        policy = F.relu(self.policy_head1(x))
        policy = self.policy_head2(policy)
        # Value head
        value = F.relu(self.value_head1(x))
        value = torch.tanh(self.value_head2(value))
        return policy, value

class Agent:
    def __init__(self, device='cpu'):
        self.device = device
        self.model = ChessNet().to(device)
        self.model.eval()

    def board_to_tensor(self, board):
        # 12x8x8 binary planes for piece types/colors
        piece_map = board.piece_map()
        tensor = np.zeros((17, 8, 8), dtype=np.float32)
        for square, piece in piece_map.items():
            idx = self.piece_to_index(piece)
            row, col = divmod(square, 8)
            tensor[idx, row, col] = 1
        # Add castling rights (4 planes)
        if board.has_kingside_castling_rights(chess.WHITE):
            tensor[12, :, :] = 1
        if board.has_queenside_castling_rights(chess.WHITE):
            tensor[13, :, :] = 1
        if board.has_kingside_castling_rights(chess.BLACK):
            tensor[14, :, :] = 1
        if board.has_queenside_castling_rights(chess.BLACK):
            tensor[15, :, :] = 1
        # Add move count (normalized, 1 plane)
        tensor[16, :, :] = board.fullmove_number / 100.0
        # Optionally, add repetition or other features here
        return torch.tensor(tensor, device=self.device).unsqueeze(0)

    def piece_to_index(self, piece):
        # 0-5: white P,N,B,R,Q,K; 6-11: black P,N,B,R,Q,K
        offset = 0 if piece.color == chess.WHITE else 6
        piece_type_map = {
            chess.PAWN: 0,
            chess.KNIGHT: 1,
            chess.BISHOP: 2,
            chess.ROOK: 3,
            chess.QUEEN: 4,
            chess.KING: 5
        }
        return offset + piece_type_map[piece.piece_type]

    def predict(self, board):
        x = self.board_to_tensor(board)
        with torch.no_grad():
            policy_logits, value = self.model(x)
        return policy_logits, value

    def diffusion_sample(self, policy_logits, steps=10, noise_scale=1.0):
        """
        Apply a simple diffusion process to the policy logits.
        At each step, add Gaussian noise and denoise by averaging with the original logits.
        """
        x = policy_logits.clone()
        orig = policy_logits.clone()
        for _ in range(steps):
            noise = torch.randn_like(x) * noise_scale
            x = x + noise
            x = (x + orig) / 2  # simple denoising step
        return x

    def predict_with_diffusion(self, board, steps=10, noise_scale=1.0):
        x = self.board_to_tensor(board)
        with torch.no_grad():
            policy_logits, value = self.model(x)
        diffused_logits = self.diffusion_sample(policy_logits, steps=steps, noise_scale=noise_scale)
        return diffused_logits, value
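
For reference, a minimal usage sketch of the Agent class above (not part of the commit). It assumes the src/ directory is on PYTHONPATH so the package imports as o1, and it only inspects the output shapes of an untrained network; the diffusion parameters shown are arbitrary toy values.

    # Sketch: query an untrained Agent on the starting position.
    import chess
    from o1.agent import Agent

    agent = Agent(device='cpu')
    board = chess.Board()

    # Plain forward pass: raw policy logits (1 x 4672) and a tanh value (1 x 1).
    policy_logits, value = agent.predict(board)
    print(policy_logits.shape, value.item())

    # "Diffusion" variant: same shapes, logits perturbed and re-averaged per step.
    diffused_logits, value = agent.predict_with_diffusion(board, steps=5, noise_scale=0.5)
    print(diffused_logits.shape, value.item())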
src/o1/mcts.py
ADDED
"""
Monte Carlo Tree Search (MCTS) for o1 agent.
Basic implementation: runs simulations, selects moves by visit count.
Integrate with neural net for policy/value guidance for full strength.
"""
import chess
import random
from collections import defaultdict
import torch

class MCTSNode:
    def __init__(self, board, parent=None, move=None):
        self.board = board.copy()
        self.parent = parent
        self.move = move
        self.children = []
        self.visits = 0
        self.value = 0.0
        self.untried_moves = list(board.legal_moves)

    def is_fully_expanded(self):
        return len(self.untried_moves) == 0

    def best_child(self, c_param=1.4):
        choices = [
            (child.value / (child.visits + 1e-6) + c_param * ((2 * (self.visits + 1e-6)) ** 0.5 / (child.visits + 1e-6)), child)
            for child in self.children
        ]
        return max(choices, key=lambda x: x[0])[1]

class MCTS:
    def __init__(self, agent=None, simulations=50):
        self.agent = agent
        self.simulations = simulations

    def search(self, board, restrict_top_n=None):
        root = MCTSNode(board)
        for _ in range(self.simulations):
            node = root
            sim_board = board.copy()
            # Selection
            while node.is_fully_expanded() and node.children:
                node = node.best_child()
                sim_board.push(node.move)
            # Expansion
            if node.untried_moves:
                move = random.choice(node.untried_moves)
                sim_board.push(move)
                child = MCTSNode(sim_board, parent=node, move=move)
                node.children.append(child)
                node.untried_moves.remove(move)
                node = child
            # Simulation
            result = self.simulate(sim_board)
            # Backpropagation
            # If it's black's turn at the node, invert the value for correct perspective
            invert = False
            temp_node = node
            while temp_node.parent is not None:
                temp_node = temp_node.parent
                invert = not invert
            value = -result if invert else result
            while node:
                node.visits += 1
                node.value += value
                node = node.parent
        # Choose move with most visits, but restrict to top-N if specified
        if not root.children:
            return random.choice(list(board.legal_moves))
        children_sorted = sorted(root.children, key=lambda c: c.visits, reverse=True)
        if restrict_top_n is not None and restrict_top_n < len(children_sorted):
            # Only consider top-N moves
            children_sorted = children_sorted[:restrict_top_n]
        best = max(children_sorted, key=lambda c: c.visits)
        return best.move

    def simulate(self, board, use_diffusion=True, diffusion_steps=10, noise_scale=1.0):
        # Use neural network to evaluate the board instead of random playout
        if self.agent is not None:
            with torch.no_grad():
                if use_diffusion and hasattr(self.agent, 'predict_with_diffusion'):
                    _, value = self.agent.predict_with_diffusion(board, steps=diffusion_steps, noise_scale=noise_scale)
                else:
                    _, value = self.agent.predict(board)
            return value.item()
        # Fallback: play random moves until game ends
        while not board.is_game_over():
            move = random.choice(list(board.legal_moves))
            board.push(move)
        result = board.result()
        if result == '1-0':
            return 1
        elif result == '0-1':
            return -1
        else:
            return 0
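
A short sketch (not part of the commit) of driving MCTS.search directly. It assumes the o1 package is importable; the simulation counts are toy values, and with agent=None the search falls back to full random playouts, which can be slow.

    # Sketch: pick a move for the starting position with a small search budget.
    import chess
    from o1.agent import Agent
    from o1.mcts import MCTS

    board = chess.Board()

    # Network-guided evaluation (the value head replaces random rollouts).
    mcts = MCTS(agent=Agent(), simulations=10)
    print("chosen move:", mcts.search(board))

    # Pure random-playout fallback (no agent attached), limited to the top 3 children.
    mcts_random = MCTS(agent=None, simulations=10)
    print("random-rollout move:", mcts_random.search(board, restrict_top_n=3))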
src/o1/selfplay.py
ADDED
"""
Self-play orchestration for o1 agent.
Runs self-play games using MCTS for move selection.
"""
import chess
from o1.mcts import MCTS

def run_selfplay(agent, num_games=1, simulations=50):
    """Run self-play games using MCTS and return experience."""
    all_experience = []
    for game_idx in range(num_games):
        board = chess.Board()
        mcts = MCTS(agent, simulations=simulations)
        game_data = []
        while not board.is_game_over():
            move = mcts.search(board)
            state_tensor = agent.board_to_tensor(board)
            # Policy: one-hot for chosen move (for now)
            policy = [0] * 4672  # 4672 matches the policy head output size (policy_size in ChessNet)
            move_idx = list(board.legal_moves).index(move)
            policy[move_idx] = 1
            value = 0  # Placeholder, will be set after game
            game_data.append((state_tensor, policy, value))
            board.push(move)
        # Assign final result as value for all positions
        result = board.result()
        if result == '1-0':
            z = 5
        elif result == '0-1':
            z = -1
        else:
            z = 0
        game_data = [(s, p, z) for (s, p, v) in game_data]
        all_experience.extend(game_data)
    return all_experience

# Self-play loop implementation will go here
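
A sketch (not part of the commit) of calling run_selfplay and inspecting one experience tuple. It assumes the o1 package is importable; a full game with an untrained network and per-move MCTS can take a while on CPU, so the simulation budget here is deliberately tiny.

    # Sketch: generate a little self-play experience and inspect one sample.
    from o1.agent import Agent
    from o1.selfplay import run_selfplay

    agent = Agent()
    experience = run_selfplay(agent, num_games=1, simulations=5)

    state, policy, z = experience[0]
    print(len(experience), "positions collected")
    print("state tensor:", state.shape)           # (1, 17, 8, 8)
    print("one-hot policy length:", len(policy))  # 4672
    print("game outcome label z:", z)             # same for every position in the game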
src/o1/train.py
ADDED
import chess
import random
import torch
import torch.nn as nn
import torch.optim as optim
from o1.agent import Agent
from o1.mcts import MCTS
from o1.utils import save_board_svg, save_model

class ExperienceBuffer:
    def __init__(self, max_size=10000):
        self.buffer = []
        self.max_size = max_size

    def add(self, experience):
        if len(self.buffer) >= self.max_size:
            self.buffer.pop(0)
        self.buffer.append(experience)

    def sample(self, batch_size):
        return random.sample(self.buffer, min(batch_size, len(self.buffer)))

    def get_tensors(self, batch):
        # Convert batch of (state_tensor, policy, value) to tensors
        # Ensure state tensors are float32 and have correct shape
        states = torch.cat([s.float() for (s, _, _) in batch], dim=0)
        policies = torch.tensor([p for (_, p, _) in batch], dtype=torch.float32)
        values = torch.tensor([v for (_, _, v) in batch], dtype=torch.float32).unsqueeze(1)
        return states, policies, values

def self_play_game(agent, simulations=10, save_svg=False, svg_prefix="game", max_moves=40):
    # Randomly choose o1's color for this game
    o1_color = random.choice([chess.WHITE, chess.BLACK])
    board = chess.Board()
    mcts = MCTS(agent, simulations=simulations)
    game_data = []
    move_num = 0
    print(f"o1 is playing as {'White' if o1_color == chess.WHITE else 'Black'}")

    while not board.is_game_over() and move_num < max_moves:
        # Determine if it's o1's turn
        o1_turn = (board.turn == o1_color)
        if o1_turn:
            move = mcts.search(board)
        else:
            # Opponent: random move
            move = random.choice(list(board.legal_moves))
        print(f"Move {move_num + 1}: {move}")
        state_tensor = agent.board_to_tensor(board)
        policy = [0] * 4672
        move_idx = list(board.legal_moves).index(move)
        policy[move_idx] = 1
        value = 0  # Placeholder, will be set after game
        game_data.append((state_tensor, policy, value))
        board.push(move)
        if save_svg:
            save_board_svg(board, f"{svg_prefix}_move{move_num}.svg")
        move_num += 1

    print(f"Game ended after {move_num} moves")
    print(f"Final position:\n{board}")

    penalty = 0
    if board.is_game_over():
        outcome = board.outcome(claim_draw=True)
        if outcome:
            termination = outcome.termination.name
            if outcome.winner is None:
                if termination == "STALEMATE":
                    winner_str = "Draw (stalemate)"
                    z = 0
                elif termination == "INSUFFICIENT_MATERIAL":
                    winner_str = "Draw (insufficient material)"
                    z = 0
                    penalty = z
                else:
                    winner_str = f"Draw ({termination.lower()})"
                    z = 0
                    penalty = z
            elif outcome.winner:
                winner_str = "White wins"
                if o1_color == chess.WHITE:
                    z = 5
                else:
                    z = -1  # Penalize o1 if it was black and lost
            else:
                winner_str = "Black wins"
                if o1_color == chess.BLACK:
                    z = 5
                else:
                    z = -1  # Penalize o1 if it was white and lost
            print(f"Game over reason: {board.result()} ({termination})")
            print(f"Result: {winner_str}")
            if penalty:
                print(f"Penalty applied: {penalty}")
        else:
            print(f"Game over reason: {board.result()} (unknown termination)")
            z = 0
            print(f"Penalty applied: {z}")
    else:
        print("Game reached move limit - applying increased penalty")
        print("Result: No winner (move limit reached)")
        z = -2.0
        print(f"Penalty applied: {z}")

    game_data = [(s, p, z) for (s, p, v) in game_data]
    if save_svg:
        save_board_svg(board, f"{svg_prefix}_final.svg")
    return game_data

def train_step(agent, buffer, optimizer, batch_size=32):
    if len(buffer.buffer) < batch_size:
        return
    batch = buffer.sample(batch_size)
    states, target_policies, target_values = buffer.get_tensors(batch)
    agent.model.train()
    optimizer.zero_grad()
    pred_policies, pred_values = agent.model(states)
    # Policy loss (cross-entropy)
    policy_loss = -torch.sum(target_policies * torch.log_softmax(pred_policies, dim=1)) / batch_size
    # Value loss (MSE)
    value_loss = nn.functional.mse_loss(pred_values, target_values)
    loss = policy_loss + value_loss
    loss.backward()
    optimizer.step()
    print(f"Train step: loss={loss.item():.4f} (policy={policy_loss.item():.4f}, value={value_loss.item():.4f})")

def main():
    agent = Agent()
    # Try to load pretrained weights if available
    import os
    from o1.utils import load_model
    pretrained_path = "trained_agent.pth"
    if os.path.exists(pretrained_path):
        print(f"Loading pretrained weights from {pretrained_path}...")
        load_model(agent, pretrained_path)
    else:
        print("No pretrained weights found. Training from scratch.")
    buffer = ExperienceBuffer()
    optimizer = optim.Adam(agent.model.parameters(), lr=1e-4)
    num_games = 20  # Number of self-play games per run
    global_reward = 0
    for i in range(num_games):
        print(f"Self-play game {i+1}")
        # Only save board SVGs for the last game
        save_svg = (i == num_games - 1)
        game_experience = self_play_game(agent, simulations=10, max_moves=300,
                                         save_svg=save_svg,
                                         svg_prefix="final_game")
        for exp in game_experience:
            buffer.add(exp)
        # Log the reward for this game (all z are the same for the game)
        if game_experience:
            game_reward = game_experience[0][2]
            global_reward += game_reward
            print(f"Reward for this game: {game_reward}")
            print(f"Cumulative global reward: {global_reward}")
        train_step(agent, buffer, optimizer)
    print("Pipeline complete. Self-play now uses MCTS for move selection and real learning.")
    # Save the trained model at the end
    save_model(agent, "trained_agent.pth")
    print("Model saved as trained_agent.pth")

if __name__ == "__main__":
    main()
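
For intuition, a self-contained sketch (not part of the commit) of the loss computed in train_step above: with one-hot policy targets the cross-entropy term reduces to the negative log-softmax at the chosen move's slot, and the value term is a plain MSE against the outcome label z. The tensors, slot indices, and z values below are dummies, not real network outputs.

    # Standalone sketch of the train_step loss on dummy tensors.
    import torch
    import torch.nn as nn

    batch_size, policy_size = 4, 4672
    pred_policies = torch.randn(batch_size, policy_size)  # stand-in for model logits
    pred_values = torch.randn(batch_size, 1)               # stand-in for tanh values

    # One-hot policy targets (as built in self_play_game) at arbitrary slots,
    # plus outcome labels z in the style used by the pipeline.
    target_policies = torch.zeros(batch_size, policy_size)
    target_policies[torch.arange(batch_size), torch.tensor([3, 17, 42, 100])] = 1.0
    target_values = torch.tensor([[1.0], [-1.0], [0.0], [-2.0]])

    policy_loss = -torch.sum(target_policies * torch.log_softmax(pred_policies, dim=1)) / batch_size
    value_loss = nn.functional.mse_loss(pred_values, target_values)
    print("policy loss:", policy_loss.item(), "value loss:", value_loss.item())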
src/o1/utils.py
ADDED
"""
Utility functions for o1 agent.
Includes ELO calculation, FEN helpers, experience save/load, and more.
"""
import pickle
import chess
import torch
import chess.svg
import os

def calculate_elo(rating_a, rating_b, result, k=32):
    """Update ELO rating for player A given result (1=win, 0.5=draw, 0=loss)."""
    expected = 1 / (1 + 10 ** ((rating_b - rating_a) / 400))
    new_rating = rating_a + k * (result - expected)
    return new_rating

def board_to_fen(board):
    """Convert a chess.Board to FEN string."""
    return board.fen()

def fen_to_board(fen):
    """Convert a FEN string to chess.Board."""
    return chess.Board(fen)

def save_experience(buffer, filename):
    """Save experience buffer to file."""
    with open(filename, 'wb') as f:
        pickle.dump(buffer.buffer, f)

def load_experience(filename):
    """Load experience buffer from file."""
    with open(filename, 'rb') as f:
        return pickle.load(f)

def save_model(agent, filename):
    """Save the agent's model to a file."""
    torch.save(agent.model.state_dict(), filename)

def load_model(agent, filename):
    """Load the agent's model from a file."""
    agent.model.load_state_dict(torch.load(filename))
    agent.model.eval()

def save_board_svg(board, filename):
    """Save the current board position as an SVG image."""
    svg_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "game_svgs")
    os.makedirs(svg_dir, exist_ok=True)
    filepath = os.path.join(svg_dir, filename)
    svg = chess.svg.board(board=board)
    with open(filepath, 'w') as f:
        f.write(svg)

# Option 3: Model Architecture Tuning
# Try deeper/smaller networks, different block types, or alternative architectures.

def try_alternative_architectures(agent_class, architectures):
    """Try different model architectures and return their performance."""
    results = {}
    for arch in architectures:
        agent = agent_class(arch=arch)
        # Evaluate agent (placeholder, implement evaluation logic)
        results[arch] = None  # Fill with actual evaluation
    return results

# Option 4: Hyperparameter Tuning
# Use grid/random search to find optimal hyperparameters.

def grid_search(train_func, param_grid):
    """Perform grid search over hyperparameters."""
    import itertools
    keys, values = zip(*param_grid.items())
    best_score = None
    best_params = None
    for v in itertools.product(*values):
        params = dict(zip(keys, v))
        score = train_func(**params)
        if best_score is None or score > best_score:
            best_score = score
            best_params = params
    return best_params, best_score

# Option 5: Regularization
# Add dropout, L2 regularization, or early stopping.

def add_regularization(model, dropout=0.2, l2=1e-4):
    """Add dropout and L2 regularization to the model."""
    # This is a placeholder; actual implementation depends on model code
    for module in model.modules():
        if hasattr(module, 'dropout'):
            module.dropout.p = dropout
    return model, l2

# Option 6: Cross-Validation
# Use k-fold cross-validation for robust evaluation.

def k_fold_cross_validation(train_func, k=5, *args, **kwargs):
    """Perform k-fold cross-validation."""
    scores = []
    for i in range(k):
        score = train_func(fold=i, *args, **kwargs)
        scores.append(score)
    return sum(scores) / len(scores)

# Option 7: Ensemble Methods
# Combine multiple models for better performance.

def ensemble_predict(models, input_tensor):
    """Average predictions from multiple models."""
    outputs = [model(input_tensor) for model in models]
    # Assume outputs are tuples (policy, value)
    avg_policy = sum([o[0] for o in outputs]) / len(outputs)
    avg_value = sum([o[1] for o in outputs]) / len(outputs)
    return avg_policy, avg_value

# ELO Evaluation for Model

def evaluate_model_elo(agent, opponent, num_games=20, initial_elo=1500):
    """Play games between agent and opponent, return estimated ELO for agent."""
    agent_elo = initial_elo
    opp_elo = initial_elo
    import random
    import chess
    from o1.mcts import MCTS
    for i in range(num_games):
        board = chess.Board()
        mcts_agent = MCTS(agent)
        mcts_opp = MCTS(opponent)
        turn = random.choice([True, False])
        while not board.is_game_over():
            if board.turn == turn:
                move = mcts_agent.search(board)
            else:
                move = mcts_opp.search(board)
            board.push(move)
        result = board.result()
        if result == '1-0':
            agent_score = 1 if turn else 0
        elif result == '0-1':
            agent_score = 0 if turn else 1
        else:
            agent_score = 0.5
        agent_elo = calculate_elo(agent_elo, opp_elo, agent_score)
        opp_elo = calculate_elo(opp_elo, agent_elo, 1 - agent_score)
    return agent_elo
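
A quick worked example (not part of the commit) of calculate_elo above: with equal ratings the expected score is 0.5, so a win at k=32 gains exactly 16 points; a loss to a higher-rated opponent costs less than a loss to an equal one.

    # Worked example for calculate_elo: equal ratings, one decisive game.
    from o1.utils import calculate_elo

    # Expected score = 1 / (1 + 10**((1500 - 1500) / 400)) = 0.5
    print(calculate_elo(1500, 1500, result=1))    # win:  1500 + 32 * (1 - 0.5) = 1516.0
    print(calculate_elo(1500, 1500, result=0.5))  # draw: 1500.0 (no change)
    print(calculate_elo(1500, 1700, result=0))    # loss to stronger player: ~1492.3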