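"""Self-play training pipeline for the o1 chess agent.

Each game pits the MCTS-guided agent against a random-move opponent and
records (state, policy, value) tuples into a replay buffer. After each game
a training step updates the network with a combined policy cross-entropy and
value MSE loss, and the trained weights are saved to disk at the end.
"""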
import os
import random

import chess
import torch
import torch.nn as nn
import torch.optim as optim

from o1.agent import Agent
from o1.mcts import MCTS
from o1.utils import load_model, save_board_svg, save_model


class ExperienceBuffer:
    def __init__(self, max_size=10000):
        self.buffer = []
        self.max_size = max_size

    def add(self, experience):
        # Drop the oldest experience once the buffer is full (FIFO).
        if len(self.buffer) >= self.max_size:
            self.buffer.pop(0)
        self.buffer.append(experience)

    def sample(self, batch_size):
        return random.sample(self.buffer, min(batch_size, len(self.buffer)))

    def get_tensors(self, batch):
        # Convert a batch of (state_tensor, policy, value) tuples into training tensors.
        # State tensors are cast to float32; values get a trailing dimension to match
        # the model's (batch, 1) value output.
        states = torch.cat([s.float() for (s, _, _) in batch], dim=0)
        policies = torch.tensor([p for (_, p, _) in batch], dtype=torch.float32)
        values = torch.tensor([v for (_, _, v) in batch], dtype=torch.float32).unsqueeze(1)
        return states, policies, values


def material_score(board, prev_board, color):
    """Return the change in material balance for `color` after a move
    (positive if `color` won material, negative if it lost material)."""
    piece_values = {
        chess.PAWN: 1,
        chess.KNIGHT: 3,
        chess.BISHOP: 3,
        chess.ROOK: 5,
        chess.QUEEN: 9,
        chess.KING: 0,
    }

    def total(b, c):
        return sum(piece_values[p.piece_type] for p in b.piece_map().values() if p.color == c)

    def balance(b):
        # Material of `color` minus material of the opponent.
        return total(b, color) - total(b, not color)

    return balance(board) - balance(prev_board)
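# Example: if o1 captures an enemy rook its material balance rises by 5, so the
# per-move reward is +5; if o1's own rook is captured instead, the reward is -5.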


def self_play_game(agent, simulations=10, save_svg=False, svg_prefix="game", max_moves=40):
    # Randomly choose o1's color for this game.
    o1_color = random.choice([chess.WHITE, chess.BLACK])
    board = chess.Board()
    mcts = MCTS(agent, simulations=simulations)
    game_data = []
    move_num = 0
    print(f"o1 is playing as {'White' if o1_color == chess.WHITE else 'Black'}")
    prev_board = board.copy()
    while not board.is_game_over() and move_num < max_moves:
        # Determine whether it is o1's turn.
        o1_turn = (board.turn == o1_color)
        if o1_turn:
            move = mcts.search(board)
        else:
            # Opponent: random legal move.
            move = random.choice(list(board.legal_moves))
        print(f"Move {move_num + 1}: {move}")
        state_tensor = agent.board_to_tensor(board)
        # One-hot policy target over the 4672-entry action space, indexed by the
        # move's position in the legal-move list (placeholder encoding).
        policy = [0] * 4672
        move_idx = list(board.legal_moves).index(move)
        policy[move_idx] = 1
        value = 0  # Placeholder; the final outcome is added after the game.
        board.push(move)
        # Material reward: positive if o1 wins material, negative if it loses material.
        mat_reward = material_score(board, prev_board, o1_color)
        prev_board = board.copy()
        game_data.append((state_tensor, policy, value + mat_reward))
        if save_svg:
            save_board_svg(board, f"{svg_prefix}_move{move_num}.svg")
        move_num += 1
    print(f"Game ended after {move_num} moves")
    print(f"Final position:\n{board}")
    # Assign the final game result z to every stored position.
    if board.is_game_over():
        outcome = board.outcome(claim_draw=True)
        if outcome:
            termination = outcome.termination.name
            if outcome.winner is None:
                if termination == "STALEMATE":
                    winner_str = "Draw (stalemate)"
                elif termination == "INSUFFICIENT_MATERIAL":
                    winner_str = "Draw (insufficient material)"
                else:
                    winner_str = f"Draw ({termination.lower()})"
                z = 0  # Draws are neutral for o1.
            elif outcome.winner == chess.WHITE:
                winner_str = "White wins"
                z = 5 if o1_color == chess.WHITE else -1  # Penalize o1 if it was Black and lost.
            else:
                winner_str = "Black wins"
                z = 5 if o1_color == chess.BLACK else -1  # Penalize o1 if it was White and lost.
            print(f"Game over reason: {board.result()} ({termination})")
            print(f"Result: {winner_str}")
        else:
            print(f"Game over reason: {board.result()} (unknown termination)")
            z = 0
            print(f"Penalty applied: {z}")
    else:
        print("Game reached move limit - applying increased penalty")
        print("Result: No winner (move limit reached)")
        z = -2.0
        print(f"Penalty applied: {z}")
    # Combine the final outcome z with each position's stored material reward.
    game_data = [(s, p, z + v) for (s, p, v) in game_data]
    if save_svg:
        save_board_svg(board, f"{svg_prefix}_final.svg")
    return game_data


def train_step(agent, buffer, optimizer, batch_size=32, early_stopping=None, patience=5, min_delta=1e-3):
    if len(buffer.buffer) < batch_size:
        return
    batch = buffer.sample(batch_size)
    states, target_policies, target_values = buffer.get_tensors(batch)
    agent.model.train()
    optimizer.zero_grad()
    pred_policies, pred_values = agent.model(states)
    # Policy loss: cross-entropy between one-hot targets and the predicted move distribution.
    policy_loss = -torch.sum(target_policies * torch.log_softmax(pred_policies, dim=1)) / states.size(0)
    # Value loss: mean squared error against the game outcome.
    value_loss = nn.functional.mse_loss(pred_values, target_values)
    loss = policy_loss + value_loss
    loss.backward()
    optimizer.step()
    print(f"Train step: loss={loss.item():.4f} (policy={policy_loss.item():.4f}, value={value_loss.item():.4f})")
    # Optional early stopping: early_stopping is a dict tracking 'best_loss' and 'epochs_no_improve'.
    if early_stopping is not None:
        if loss.item() < early_stopping['best_loss'] - min_delta:
            early_stopping['best_loss'] = loss.item()
            early_stopping['epochs_no_improve'] = 0
        else:
            early_stopping['epochs_no_improve'] += 1
            if early_stopping['epochs_no_improve'] >= patience:
                print("Early stopping triggered: no improvement.")
                return 'stop'
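# Callers that want early stopping would pass a mutable state dict, e.g.
#   state = {'best_loss': float('inf'), 'epochs_no_improve': 0}
#   if train_step(agent, buffer, optimizer, early_stopping=state) == 'stop': ...
# (main() below does not use it.)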


def main():
    agent = Agent()
    # Load pretrained weights if available; otherwise train from scratch.
    pretrained_path = "trained_agent.pth"
    if os.path.exists(pretrained_path):
        print(f"Loading pretrained weights from {pretrained_path}...")
        load_model(agent, pretrained_path)
    else:
        print("No pretrained weights found. Training from scratch.")
    buffer = ExperienceBuffer()
    optimizer = optim.Adam(agent.model.parameters(), lr=1e-4)
    num_games = 10  # Number of self-play games per run.
    global_reward = 0
    for i in range(num_games):
        print(f"Self-play game {i+1}")
        # Only save board SVGs for the last game.
        save_svg = (i == num_games - 1)
        game_experience = self_play_game(agent, simulations=10, max_moves=150,
                                         save_svg=save_svg,
                                         svg_prefix="final_game")
        for exp in game_experience:
            buffer.add(exp)
        # Log the reward for this game (the first position's value equals the
        # final outcome z, since no material can change on the opening move).
        if game_experience:
            game_reward = game_experience[0][2]
            global_reward += game_reward
            print(f"Reward for this game: {game_reward}")
            print(f"Cumulative global reward: {global_reward}")
        train_step(agent, buffer, optimizer)
    print("Pipeline complete. Self-play now uses MCTS for move selection and real learning.")
    # Save the trained model at the end.
    save_model(agent, "trained_agent.pth")
    print("Model saved as trained_agent.pth")


if __name__ == "__main__":
    main()