# play-with-o1/src/o1/train.py
import os
import random

import chess
import torch
import torch.nn as nn
import torch.optim as optim

from o1.agent import Agent
from o1.mcts import MCTS
from o1.utils import load_model, save_board_svg, save_model
class ExperienceBuffer:
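    """Fixed-size FIFO replay buffer of (state_tensor, policy, value) examples from self-play."""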
def __init__(self, max_size=10000):
self.buffer = []
self.max_size = max_size
def add(self, experience):
if len(self.buffer) >= self.max_size:
self.buffer.pop(0)
self.buffer.append(experience)
def sample(self, batch_size):
return random.sample(self.buffer, min(batch_size, len(self.buffer)))
def get_tensors(self, batch):
# Convert batch of (state_tensor, policy, value) to tensors
# Ensure state tensors are float32 and have correct shape
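        # Each state tensor is assumed to already carry a leading batch dimension of 1 so they concatenate on dim 0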
states = torch.cat([s.float() for (s, _, _) in batch], dim=0)
policies = torch.tensor([p for (_, p, _) in batch], dtype=torch.float32)
values = torch.tensor([v for (_, _, v) in batch], dtype=torch.float32).unsqueeze(1)
return states, policies, values
def material_score(board, prev_board, color):
"""Return the material score difference for color after a move (positive if color captured, negative if lost)."""
piece_values = {
chess.PAWN: 1,
chess.KNIGHT: 3,
chess.BISHOP: 3,
chess.ROOK: 5,
chess.QUEEN: 9,
chess.KING: 0
}
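    # Example: if o1's rook was captured between prev_board and board, the returned score is -5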
def total(board, color):
return sum(piece_values[p.piece_type] for p in board.piece_map().values() if p.color == color)
return total(board, color) - total(prev_board, color)
def self_play_game(agent, simulations=10, save_svg=False, svg_prefix="game", max_moves=40):
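    """Play one game of o1 (MCTS move selection) against a random-move opponent.

    Records a (state_tensor, policy, reward) tuple for every half-move played and
    returns the list after folding in the final game outcome.
    """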
# Randomly choose o1's color for this game
o1_color = random.choice([chess.WHITE, chess.BLACK])
board = chess.Board()
mcts = MCTS(agent, simulations=simulations)
game_data = []
move_num = 0
print(f"o1 is playing as {'White' if o1_color == chess.WHITE else 'Black'}")
prev_board = board.copy()
while not board.is_game_over() and move_num < max_moves:
# Determine if it's o1's turn
o1_turn = (board.turn == o1_color)
if o1_turn:
move = mcts.search(board)
else:
# Opponent: random move
move = random.choice(list(board.legal_moves))
print(f"Move {move_num + 1}: {move}")
state_tensor = agent.board_to_tensor(board)
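        # Policy target: one-hot vector of length 4672, indexed by the move's position
        # in the current legal-move list rather than by a global move encoding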
policy = [0] * 4672
move_idx = list(board.legal_moves).index(move)
policy[move_idx] = 1
        value = 0  # Outcome placeholder; the final result z is added after the game
board.push(move)
# Material reward: positive if o1 captures, negative if o1 loses material
mat_reward = material_score(board, prev_board, o1_color)
prev_board = board.copy()
game_data.append((state_tensor, policy, value + mat_reward))
if save_svg:
save_board_svg(board, f"{svg_prefix}_move{move_num}.svg")
move_num += 1
print(f"Game ended after {move_num} moves")
print(f"Final position:\n{board}")
penalty = 0
if board.is_game_over():
outcome = board.outcome(claim_draw=True)
if outcome:
termination = outcome.termination.name
            if outcome.winner is None:
                # Every draw termination (stalemate, insufficient material, repetition, ...) scores 0 for o1
                winner_str = f"Draw ({termination.replace('_', ' ').lower()})"
                z = 0
                penalty = z
elif outcome.winner:
winner_str = "White wins"
if o1_color == chess.WHITE:
z = 5
else:
z = -1 # Penalize o1 if it was black and lost
else:
winner_str = "Black wins"
if o1_color == chess.BLACK:
z = 5
else:
z = -1 # Penalize o1 if it was white and lost
print(f"Game over reason: {board.result()} ({termination})")
print(f"Result: {winner_str}")
if penalty:
print(f"Penalty applied: {penalty}")
else:
print(f"Game over reason: {board.result()} (unknown termination)")
z = 0
print(f"Penalty applied: {z}")
else:
print("Game reached move limit - applying increased penalty")
print("Result: No winner (move limit reached)")
z = -2.0
print(f"Penalty applied: {z}")
    # Fold the final outcome z into every example, keeping each move's material-shaping term
    game_data = [(s, p, v + z) for (s, p, v) in game_data]
if save_svg:
save_board_svg(board, f"{svg_prefix}_final.svg")
return game_data
def train_step(agent, buffer, optimizer, batch_size=32, early_stopping=None, patience=5, min_delta=1e-3):
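    """Sample a minibatch from the buffer and take one optimizer step.

    Returns 'stop' if the optional early-stopping criterion is met, otherwise None.
    """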
if len(buffer.buffer) < batch_size:
return
batch = buffer.sample(batch_size)
states, target_policies, target_values = buffer.get_tensors(batch)
agent.model.train()
optimizer.zero_grad()
pred_policies, pred_values = agent.model(states)
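    # The model is assumed to return (policy_logits, value) with shapes (batch, 4672) and (batch, 1)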
# Policy loss (cross-entropy)
policy_loss = -torch.sum(target_policies * torch.log_softmax(pred_policies, dim=1)) / batch_size
# Value loss (MSE)
value_loss = nn.functional.mse_loss(pred_values, target_values)
loss = policy_loss + value_loss
loss.backward()
optimizer.step()
print(f"Train step: loss={loss.item():.4f} (policy={policy_loss.item():.4f}, value={value_loss.item():.4f})")
# Early stopping logic (if enabled)
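    # early_stopping, if provided, is a mutable dict like {'best_loss': float('inf'), 'epochs_no_improve': 0}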
if early_stopping is not None:
if loss.item() < early_stopping['best_loss'] - min_delta:
early_stopping['best_loss'] = loss.item()
early_stopping['epochs_no_improve'] = 0
else:
early_stopping['epochs_no_improve'] += 1
if early_stopping['epochs_no_improve'] >= patience:
print("Early stopping triggered: no improvement.")
return 'stop'
def main():
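    """Run the self-play and training pipeline, then save the trained model."""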
agent = Agent()
# Try to load pretrained weights if available
pretrained_path = "trained_agent.pth"
if os.path.exists(pretrained_path):
print(f"Loading pretrained weights from {pretrained_path}...")
load_model(agent, pretrained_path)
else:
print("No pretrained weights found. Training from scratch.")
buffer = ExperienceBuffer()
optimizer = optim.Adam(agent.model.parameters(), lr=1e-4)
    num_games = 10  # Number of self-play games per training run
global_reward = 0
for i in range(num_games):
print(f"Self-play game {i+1}")
        # Only save board SVGs for the last game
        save_svgs = (i == num_games - 1)
        game_experience = self_play_game(agent, simulations=10, max_moves=150,
                                         save_svg=save_svgs,
                                         svg_prefix="final_game")
for exp in game_experience:
buffer.add(exp)
        # Log the outcome reward for this game (the first move cannot capture, so its stored reward equals z)
if game_experience:
game_reward = game_experience[0][2]
global_reward += game_reward
print(f"Reward for this game: {game_reward}")
print(f"Cumulative global reward: {global_reward}")
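        # Optional: to enable early stopping, build the state dict once before the loop,
        # e.g. es = {'best_loss': float('inf'), 'epochs_no_improve': 0}, pass it as
        # train_step(agent, buffer, optimizer, early_stopping=es), and break when it returns 'stop'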
train_step(agent, buffer, optimizer)
print("Pipeline complete. Self-play now uses MCTS for move selection and real learning.")
# Save the trained model at the end
save_model(agent, "trained_agent.pth")
print("Model saved as trained_agent.pth")
if __name__ == "__main__":
main()