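"""Self-play training pipeline for the o1 chess agent.

o1 plays against a random-move opponent, using MCTS for its own move selection.
Each game produces (state, policy, value) experiences that are stored in a
replay buffer and used to train the agent's policy/value network.
"""
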
import chess
import random
import torch
import torch.nn as nn
import torch.optim as optim
from o1.agent import Agent
from o1.mcts import MCTS
from o1.utils import save_board_svg, save_model
class ExperienceBuffer:
    def __init__(self, max_size=10000):
        self.buffer = []
        self.max_size = max_size

    def add(self, experience):
        # Drop the oldest experience once the buffer is full (FIFO).
        if len(self.buffer) >= self.max_size:
            self.buffer.pop(0)
        self.buffer.append(experience)

    def sample(self, batch_size):
        return random.sample(self.buffer, min(batch_size, len(self.buffer)))

    def get_tensors(self, batch):
        # Convert a batch of (state_tensor, policy, value) tuples to training tensors.
        # State tensors are cast to float32; values get a trailing dimension to match
        # the value head's (batch_size, 1) output shape.
        states = torch.cat([s.float() for (s, _, _) in batch], dim=0)
        policies = torch.tensor([p for (_, p, _) in batch], dtype=torch.float32)
        values = torch.tensor([v for (_, _, v) in batch], dtype=torch.float32).unsqueeze(1)
        return states, policies, values
def material_score(board, prev_board, color):
    """Return the material swing for `color` between two positions
    (positive if `color` came out ahead in material, negative if it lost material)."""
    piece_values = {
        chess.PAWN: 1,
        chess.KNIGHT: 3,
        chess.BISHOP: 3,
        chess.ROOK: 5,
        chess.QUEEN: 9,
        chess.KING: 0,
    }

    def total(b, c):
        return sum(piece_values[p.piece_type] for p in b.piece_map().values() if p.color == c)

    # Compare the change in color's own material against the change in the
    # opponent's material, so captures by `color` count as a positive swing.
    own_delta = total(board, color) - total(prev_board, color)
    opp_delta = total(board, not color) - total(prev_board, not color)
    return own_delta - opp_delta
def self_play_game(agent, simulations=10, save_svg=False, svg_prefix="game", max_moves=40):
    # Randomly choose o1's color for this game.
    o1_color = random.choice([chess.WHITE, chess.BLACK])
    board = chess.Board()
    mcts = MCTS(agent, simulations=simulations)
    game_data = []
    move_num = 0
    print(f"o1 is playing as {'White' if o1_color == chess.WHITE else 'Black'}")
    prev_board = board.copy()
    while not board.is_game_over() and move_num < max_moves:
        # Determine whether it is o1's turn.
        o1_turn = (board.turn == o1_color)
        if o1_turn:
            move = mcts.search(board)
        else:
            # Opponent: random move.
            move = random.choice(list(board.legal_moves))
        print(f"Move {move_num + 1}: {move}")
        state_tensor = agent.board_to_tensor(board)
        # One-hot policy target over a 4672-entry move space, indexed by the
        # chosen move's position in the current legal-move list.
        policy = [0] * 4672
        move_idx = list(board.legal_moves).index(move)
        policy[move_idx] = 1
        value = 0  # Placeholder; the game outcome is filled in after the game ends.
        board.push(move)
        # Material reward: positive if o1 gained material on this move, negative if it lost material.
        mat_reward = material_score(board, prev_board, o1_color)
        prev_board = board.copy()
        game_data.append((state_tensor, policy, value + mat_reward))
        if save_svg:
            save_board_svg(board, f"{svg_prefix}_move{move_num}.svg")
        move_num += 1
print(f"Game ended after {move_num} moves")
print(f"Final position:\n{board}")
penalty = 0
if board.is_game_over():
outcome = board.outcome(claim_draw=True)
if outcome:
termination = outcome.termination.name
if outcome.winner is None:
if termination == "STALEMATE":
winner_str = "Draw (stalemate)"
z = 0
elif termination == "INSUFFICIENT_MATERIAL":
winner_str = "Draw (insufficient material)"
z = 0
penalty = z
else:
winner_str = f"Draw ({termination.lower()})"
z = 0
penalty = z
elif outcome.winner:
winner_str = "White wins"
if o1_color == chess.WHITE:
z = 5
else:
z = -1 # Penalize o1 if it was black and lost
else:
winner_str = "Black wins"
if o1_color == chess.BLACK:
z = 5
else:
z = -1 # Penalize o1 if it was white and lost
print(f"Game over reason: {board.result()} ({termination})")
print(f"Result: {winner_str}")
if penalty:
print(f"Penalty applied: {penalty}")
else:
print(f"Game over reason: {board.result()} (unknown termination)")
z = 0
print(f"Penalty applied: {z}")
else:
print("Game reached move limit - applying increased penalty")
print("Result: No winner (move limit reached)")
z = -2.0
print(f"Penalty applied: {z}")
game_data = [(s, p, z) for (s, p, v) in game_data]
if save_svg:
save_board_svg(board, f"{svg_prefix}_final.svg")
return game_data
def train_step(agent, buffer, optimizer, batch_size=32, early_stopping=None, patience=5, min_delta=1e-3):
    # Skip training until the buffer holds at least one full batch.
    if len(buffer.buffer) < batch_size:
        return
    batch = buffer.sample(batch_size)
    states, target_policies, target_values = buffer.get_tensors(batch)
    agent.model.train()
    optimizer.zero_grad()
    pred_policies, pred_values = agent.model(states)
    # Policy loss: cross-entropy between the target distribution and the predicted logits.
    policy_loss = -torch.sum(target_policies * torch.log_softmax(pred_policies, dim=1)) / batch_size
    # Value loss: mean squared error against the game outcome.
    value_loss = nn.functional.mse_loss(pred_values, target_values)
    loss = policy_loss + value_loss
    loss.backward()
    optimizer.step()
    print(f"Train step: loss={loss.item():.4f} (policy={policy_loss.item():.4f}, value={value_loss.item():.4f})")
    # Early stopping (optional): track the best loss and stop after `patience` steps without improvement.
    if early_stopping is not None:
        if loss.item() < early_stopping['best_loss'] - min_delta:
            early_stopping['best_loss'] = loss.item()
            early_stopping['epochs_no_improve'] = 0
        else:
            early_stopping['epochs_no_improve'] += 1
        if early_stopping['epochs_no_improve'] >= patience:
            print("Early stopping triggered: no improvement.")
            return 'stop'
def main():
    agent = Agent()
    # Try to load pretrained weights if available.
    import os
    from o1.utils import load_model
    pretrained_path = "trained_agent.pth"
    if os.path.exists(pretrained_path):
        print(f"Loading pretrained weights from {pretrained_path}...")
        load_model(agent, pretrained_path)
    else:
        print("No pretrained weights found. Training from scratch.")
    buffer = ExperienceBuffer()
    optimizer = optim.Adam(agent.model.parameters(), lr=1e-4)
    num_games = 10  # Number of self-play games per training run.
    global_reward = 0
    for i in range(num_games):
        print(f"Self-play game {i + 1}")
        # Only save board SVGs for the last game.
        save_svg = (i == num_games - 1)
        game_experience = self_play_game(agent, simulations=10, max_moves=150,
                                         save_svg=save_svg,
                                         svg_prefix="final_game")
        for exp in game_experience:
            buffer.add(exp)
        # Log the reward of the game's first position (no capture is possible on
        # the first move, so this equals the game outcome z).
        if game_experience:
            game_reward = game_experience[0][2]
            global_reward += game_reward
            print(f"Reward for this game: {game_reward}")
            print(f"Cumulative global reward: {global_reward}")
        train_step(agent, buffer, optimizer)
    print("Pipeline complete. Self-play now uses MCTS for move selection and real learning.")
    # Save the trained model at the end.
    save_model(agent, "trained_agent.pth")
    print("Model saved as trained_agent.pth")


if __name__ == "__main__":
    main()