"""O2 chess agent: picks moves with an O2Net policy, optionally via MCTS."""

import random

import chess
import numpy as np
import torch

from o2_model import O2Net, board_to_tensor
from mcts import MCTS

# Optional: endgame tablebase and opening book integration placeholders.
# python-chess ships modules for both if you want to wire them in.
#
# Example for a Syzygy endgame tablebase:
# import chess.syzygy
# with chess.syzygy.open_tablebase("/path/to/syzygy") as tablebase:
#     wdl = tablebase.get_wdl(board)  # None if the position is not covered
#     if wdl is not None:
#         ...  # use the tablebase result to pick a move
#
# Example for a Polyglot opening book:
# import chess.polyglot
# with chess.polyglot.open_reader("book.bin") as reader:
#     entry = reader.find(board)  # raises IndexError if no entry exists
#     move = entry.move


class O2Agent:
    """Chess move selector backed by an ``O2Net`` policy/value network.

    Moves are chosen either by Monte Carlo Tree Search (delegated to
    ``mcts.MCTS``) or directly from the network's policy output.
    """

    # Under-promotion pieces, in the slot order used by ``move_to_index``.
    # Queen promotions are not listed: they share the plain from/to range.
    _UNDERPROMOTION_PIECES = (chess.KNIGHT, chess.BISHOP, chess.ROOK)
    # First index after the 64 * 64 from/to block.
    _UNDERPROMOTION_BASE = 64 * 64
    # 3 pieces * 3 file deltas (-1, 0, +1) * 2 sides * 8 from-files.
    _UNDERPROMOTION_COUNT = 3 * 3 * 2 * 8

    def __init__(self, model_path=None):
        """Build the network and optionally load weights.

        Args:
            model_path: Path to a saved ``state_dict``. When falsy, the
                freshly initialised network is used as-is.
        """
        self.model = O2Net()
        if model_path:
            # map_location="cpu" lets checkpoints saved on a GPU load on a
            # CPU-only machine; load_state_dict then copies the values into
            # the model's own parameters.
            # NOTE(review): torch.load unpickles arbitrary objects — only load
            # trusted checkpoints (consider weights_only=True on newer torch).
            state_dict = torch.load(model_path, map_location="cpu")
            self.model.load_state_dict(state_dict)
        # Inference only: switch layers such as dropout/batch-norm to eval mode.
        self.model.eval()

    def select_move(self, board, use_mcts=True, simulations=100, temperature=0.0):
        """Choose a move for the side to play on ``board``.

        Args:
            board: The ``chess.Board`` to move on.
            use_mcts: If true, run ``MCTS`` and return its result.
            simulations: Number of MCTS simulations (MCTS path only).
            temperature: ``0``/``None`` picks the highest-scoring move;
                a positive value samples from ``softmax(scores / temperature)``
                over the legal moves' policy scores.

        Returns:
            The selected move (a legal ``chess.Move`` on the policy path;
            on the MCTS path, whatever ``MCTS.run`` returns).

        Raises:
            ValueError: On the policy path, if ``board`` has no legal moves.
        """
        if use_mcts:
            mcts = MCTS(self.model, simulations=simulations)
            return mcts.run(board, temperature=temperature)

        legal_moves = list(board.legal_moves)
        if not legal_moves:
            raise ValueError("no legal moves: the game is already over")

        tensor = torch.tensor(board_to_tensor(board)).unsqueeze(0)
        with torch.no_grad():
            policy, _ = self.model(tensor)

        # Gather every legal move's score in one indexing operation.
        # Assumes policy is (batch, N) with N covering move_to_index's range
        # (< 4240; presumably N == 4672) — TODO confirm against O2Net.
        indices = [self.move_to_index(move) for move in legal_moves]
        scores = policy[0, indices].double().cpu().numpy()

        if temperature and temperature > 0:
            logits = scores / temperature
            # Subtract the max before exp() so small temperatures cannot
            # overflow to inf (which would turn the probabilities into NaN).
            logits -= logits.max()
            probs = np.exp(logits)
            probs /= probs.sum()
            # Sample an index rather than the Move objects themselves so numpy
            # never has to build an object array out of them.
            return legal_moves[np.random.choice(len(legal_moves), p=probs)]

        return legal_moves[int(np.argmax(scores))]

    def move_to_index(self, move):
        """Map ``move`` to a flat policy index.

        Layout (unique for the legal moves of any position):

        * ``[0, 4096)``: ``from_square * 64 + to_square``. Plain moves and
          queen promotions share this range. A from/to pair still identifies
          a queen promotion unambiguously, because a pawn reaching the last
          rank must promote.
        * ``[4096, 4240)``: under-promotions to knight, bishop or rook,
          encoded by piece slot, file delta (-1, 0, +1), promoting side and
          from-file.

        ``index_to_move`` is the exact inverse for legal moves.
        """
        base = move.from_square * 64 + move.to_square
        if move.promotion not in self._UNDERPROMOTION_PIECES:
            return base

        from_file = chess.square_file(move.from_square)
        file_delta = chess.square_file(move.to_square) - from_file
        if abs(file_delta) > 1:
            # Not a real pawn move; keep it inside the plain from/to range.
            return base

        piece_slot = self._UNDERPROMOTION_PIECES.index(move.promotion)
        # White promotes onto rank index 7, black onto rank index 0.
        side = 0 if chess.square_rank(move.to_square) == 7 else 1
        offset = ((piece_slot * 3 + (file_delta + 1)) * 2 + side) * 8 + from_file
        return self._UNDERPROMOTION_BASE + offset

    def index_to_move(self, board, index):
        """Decode a policy ``index`` (the inverse of ``move_to_index``).

        Args:
            board: Position used to validate the decoded move.
            index: Integer-like policy index.

        Returns:
            The decoded move when it is legal on ``board``; otherwise a
            uniformly random legal move.

        Raises:
            IndexError: If ``board`` has no legal moves (from ``random.choice``).
        """
        index = int(index)
        base_end = self._UNDERPROMOTION_BASE
        if 0 <= index < base_end:
            from_square, to_square = divmod(index, 64)
            # The plain move and the queen promotion share this index;
            # at most one of them can be legal.
            for promotion in (None, chess.QUEEN):
                move = chess.Move(from_square, to_square, promotion=promotion)
                if move in board.legal_moves:
                    return move
        elif base_end <= index < base_end + self._UNDERPROMOTION_COUNT:
            rest, from_file = divmod(index - base_end, 8)
            rest, side = divmod(rest, 2)
            piece_slot, delta_slot = divmod(rest, 3)
            to_file = from_file + delta_slot - 1
            if 0 <= to_file < 8:
                from_rank, to_rank = (6, 7) if side == 0 else (1, 0)
                move = chess.Move(
                    chess.square(from_file, from_rank),
                    chess.square(to_file, to_rank),
                    promotion=self._UNDERPROMOTION_PIECES[piece_slot],
                )
                if move in board.legal_moves:
                    return move

        # Fallback: pick a random legal move.
        return random.choice(list(board.legal_moves))


if __name__ == "__main__":
    board = chess.Board()
    agent = O2Agent()
    move = agent.select_move(board)
    print("O2 selects:", move)