import os
import random

import chess
import torch
import torch.nn as nn
import torch.optim as optim

from o1.agent import Agent
from o1.mcts import MCTS
from o1.utils import load_model, save_board_svg, save_model


class ExperienceBuffer:
    """Fixed-size FIFO buffer of (state_tensor, policy, value) training examples."""

    def __init__(self, max_size=10000):
        self.buffer = []
        self.max_size = max_size

    def add(self, experience):
        if len(self.buffer) >= self.max_size:
            self.buffer.pop(0)
        self.buffer.append(experience)

    def sample(self, batch_size):
        return random.sample(self.buffer, min(batch_size, len(self.buffer)))

    def get_tensors(self, batch):
        # Convert a batch of (state_tensor, policy, value) tuples into stacked tensors.
        # State tensors are cast to float32; values gain a trailing dim to match the value head.
        states = torch.cat([s.float() for (s, _, _) in batch], dim=0)
        policies = torch.tensor([p for (_, p, _) in batch], dtype=torch.float32)
        values = torch.tensor([v for (_, _, v) in batch], dtype=torch.float32).unsqueeze(1)
        return states, policies, values


def material_score(board, prev_board, color):
    """Return the change in material for `color` between `prev_board` and `board`
    (positive if color gained material, negative if it lost material)."""
    piece_values = {
        chess.PAWN: 1,
        chess.KNIGHT: 3,
        chess.BISHOP: 3,
        chess.ROOK: 5,
        chess.QUEEN: 9,
        chess.KING: 0,
    }

    def total(b, c):
        return sum(piece_values[p.piece_type] for p in b.piece_map().values() if p.color == c)

    return total(board, color) - total(prev_board, color)


def self_play_game(agent, simulations=10, save_svg=False, svg_prefix="game", max_moves=40):
    # Randomly choose o1's color for this game; the opponent plays random moves.
    o1_color = random.choice([chess.WHITE, chess.BLACK])
    board = chess.Board()
    mcts = MCTS(agent, simulations=simulations)
    game_data = []
    move_num = 0
    print(f"o1 is playing as {'White' if o1_color == chess.WHITE else 'Black'}")
    prev_board = board.copy()

    while not board.is_game_over() and move_num < max_moves:
        # Determine if it's o1's turn
        o1_turn = (board.turn == o1_color)
        if o1_turn:
            move = mcts.search(board)
        else:
            # Opponent: uniformly random legal move
            move = random.choice(list(board.legal_moves))
        print(f"Move {move_num + 1}: {move}")

        state_tensor = agent.board_to_tensor(board)
        # One-hot policy target over the index of the chosen move within the
        # current legal-move list (padded out to the 4672-slot policy head).
        policy = [0] * 4672
        move_idx = list(board.legal_moves).index(move)
        policy[move_idx] = 1
        value = 0  # Placeholder; the game outcome z is added after the game.

        board.push(move)
        # Material reward: positive if o1 gained material on this move, negative if it lost material.
        mat_reward = material_score(board, prev_board, o1_color)
        prev_board = board.copy()
        game_data.append((state_tensor, policy, value + mat_reward))

        if save_svg:
            save_board_svg(board, f"{svg_prefix}_move{move_num}.svg")
        move_num += 1

    print(f"Game ended after {move_num} moves")
    print(f"Final position:\n{board}")

    penalty = 0
    if board.is_game_over():
        outcome = board.outcome(claim_draw=True)
        if outcome:
            termination = outcome.termination.name
            if outcome.winner is None:
                if termination == "STALEMATE":
                    winner_str = "Draw (stalemate)"
                    z = 0
                    penalty = z
                elif termination == "INSUFFICIENT_MATERIAL":
                    winner_str = "Draw (insufficient material)"
                    z = 0
                    penalty = z
                else:
                    winner_str = f"Draw ({termination.lower()})"
                    z = 0
                    penalty = z
            elif outcome.winner:
                winner_str = "White wins"
                z = 5 if o1_color == chess.WHITE else -1  # Penalize o1 if it was Black and lost
            else:
                winner_str = "Black wins"
                z = 5 if o1_color == chess.BLACK else -1  # Penalize o1 if it was White and lost
            print(f"Game over reason: {board.result()} ({termination})")
            print(f"Result: {winner_str}")
            if penalty:
                print(f"Penalty applied: {penalty}")
        else:
            print(f"Game over reason: {board.result()} (unknown termination)")
            z = 0
            print(f"Penalty applied: {z}")
    else:
        print("Game reached move limit - applying increased penalty")
        print("Result: No winner (move limit reached)")
        z = -2.0
        print(f"Penalty applied: {z}")

    # Combine the shared game outcome z with the per-move material rewards collected during play.
    game_data = [(s, p, z + v) for (s, p, v) in game_data]
    if save_svg:
        save_board_svg(board, f"{svg_prefix}_final.svg")
    return game_data


def train_step(agent, buffer, optimizer, batch_size=32, early_stopping=None, patience=5, min_delta=1e-3):
    if len(buffer.buffer) < batch_size:
        return
    batch = buffer.sample(batch_size)
    states, target_policies, target_values = buffer.get_tensors(batch)

    agent.model.train()
    optimizer.zero_grad()
    pred_policies, pred_values = agent.model(states)

    # Policy loss: cross-entropy against the one-hot move targets
    policy_loss = -torch.sum(target_policies * torch.log_softmax(pred_policies, dim=1)) / batch_size
    # Value loss: MSE against the game outcome / material targets
    value_loss = nn.functional.mse_loss(pred_values, target_values)
    loss = policy_loss + value_loss
    loss.backward()
    optimizer.step()
    print(f"Train step: loss={loss.item():.4f} (policy={policy_loss.item():.4f}, value={value_loss.item():.4f})")

    # Early stopping logic (if enabled); `early_stopping` is a caller-owned dict
    # holding 'best_loss' and 'epochs_no_improve'.
    if early_stopping is not None:
        if loss.item() < early_stopping['best_loss'] - min_delta:
            early_stopping['best_loss'] = loss.item()
            early_stopping['epochs_no_improve'] = 0
        else:
            early_stopping['epochs_no_improve'] += 1
            if early_stopping['epochs_no_improve'] >= patience:
                print("Early stopping triggered: no improvement.")
                return 'stop'


def main():
    agent = Agent()
    # Load pretrained weights if available, otherwise train from scratch.
    pretrained_path = "trained_agent.pth"
    if os.path.exists(pretrained_path):
        print(f"Loading pretrained weights from {pretrained_path}...")
        load_model(agent, pretrained_path)
    else:
        print("No pretrained weights found. Training from scratch.")

    buffer = ExperienceBuffer()
    optimizer = optim.Adam(agent.model.parameters(), lr=1e-4)
    num_games = 10  # Number of self-play games per training run
    global_reward = 0

    for i in range(num_games):
        print(f"Self-play game {i + 1}")
        # Only save board SVGs for the last game
        save_svgs = (i == num_games - 1)
        game_experience = self_play_game(agent, simulations=10, max_moves=150,
                                         save_svg=save_svgs, svg_prefix="final_game")
        for exp in game_experience:
            buffer.add(exp)

        # Log the outcome reward for this game. The outcome z is shared by every move,
        # and the first move can never be a capture, so the first entry's value equals z.
        if game_experience:
            game_reward = game_experience[0][2]
            global_reward += game_reward
            print(f"Reward for this game: {game_reward}")
            print(f"Cumulative global reward: {global_reward}")

        train_step(agent, buffer, optimizer)

    print("Pipeline complete. Self-play now uses MCTS for move selection and real learning.")
    # Save the trained model at the end
    save_model(agent, "trained_agent.pth")
    print("Model saved as trained_agent.pth")


if __name__ == "__main__":
    main()
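
# ---------------------------------------------------------------------------
# Illustrative sketch (not invoked by the pipeline above): loading the saved
# weights back for a quick sanity check. It assumes load_model and MCTS.search
# behave exactly as they are used in main() and self_play_game(); the
# simulation count and path here are arbitrary.
#
#   agent = Agent()
#   load_model(agent, "trained_agent.pth")
#   mcts = MCTS(agent, simulations=50)
#   board = chess.Board()
#   print(f"Suggested opening move: {mcts.search(board)}")
# ---------------------------------------------------------------------------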