"""Monte Carlo tree search guided by a policy/value network."""

import chess
import torch
import numpy as np
from o2_model import board_to_tensor

# Length of the flat policy vector the model is expected to return.
POLICY_SIZE = 4672

# Below this total prior mass on the legal moves, the network policy is
# treated as uninformative and replaced by a uniform distribution.
_MIN_PRIOR_MASS = 1e-8


class MCTSNode:
    """One position in the search tree together with its search statistics."""

    def __init__(self, board, parent=None, move=None):
        """Create a node for ``board``.

        Args:
            board: Position this node represents; it is copied.
            parent: Node this one was reached from, or None for the root.
            move: Move that leads from ``parent`` to this node.
        """
        # Copy so later pushes/pops on the caller's board cannot alter the tree.
        self.board = board.copy()
        self.parent = parent
        self.move = move
        self.children = {}  # move -> MCTSNode
        self.N = 0  # Visit count
        self.W = 0  # Total value
        self.Q = 0  # Mean value
        self.P = 0  # Prior probability


class MCTS:
    """PUCT search that uses ``model`` for move priors and leaf values."""

    def __init__(self, model, simulations=100, c_puct=1.5):
        """Store the search configuration.

        Args:
            model: Callable taking a batched board tensor and returning
                ``(policy_logits, value)``.
            simulations: Number of select/expand/backup iterations per search.
            c_puct: Exploration constant in the PUCT formula.
        """
        self.model = model
        self.simulations = simulations
        self.c_puct = c_puct

    def run(self, board, temperature=0.0):
        """Search from ``board`` and return the chosen move.

        Args:
            board: Position to search from; it is not modified.
            temperature: If greater than zero, the move is sampled with
                probability proportional to ``visits ** (1 / temperature)``.
                Otherwise the most visited move is returned.

        Returns:
            One of the legal moves of ``board``.

        Raises:
            ValueError: If ``board`` has no legal moves (the game is over).
        """
        root = MCTSNode(board)
        self._expand(root)
        if not root.children:
            raise ValueError("cannot search a position with no legal moves")

        for _ in range(self.simulations):
            node = root
            search_path = [node]

            # Selection: descend until a node without children is reached
            # (either not yet expanded or terminal).
            while node.children:
                node = self._select_child(node)
                search_path.append(node)

            # Expansion / evaluation.
            value = self._expand(node)

            # Backpropagation: the sign flips at every ply because
            # consecutive nodes belong to opposite players.
            for visited in reversed(search_path):
                visited.N += 1
                visited.W += value
                visited.Q = visited.W / visited.N
                value = -value

        moves = list(root.children)

        # Temperature-based sampling for opening diversity.
        if temperature is not None and temperature > 0:
            visits = np.array(
                [root.children[m].N for m in moves], dtype=np.float64
            )
            peak = visits.max()
            if peak > 0:
                # Dividing by the maximum leaves the distribution unchanged
                # but keeps the power in [0, 1], so a small temperature
                # cannot overflow to inf/NaN.
                weights = (visits / peak) ** (1.0 / temperature)
            else:
                # No simulations were run; nothing to prefer.
                weights = np.ones_like(visits)
            probs = weights / weights.sum()
            # Sample an index so NumPy never has to coerce move objects.
            return moves[int(np.random.choice(len(moves), p=probs))]

        # Choose the move with the highest visit count (first one on ties).
        return max(moves, key=lambda m: root.children[m].N)

    def _select_child(self, node):
        """Return the child of ``node`` with the highest PUCT score."""
        # Floor the parent count at 1: the root has N == 0 on the first
        # simulation, which would zero the exploration term and make the
        # choice ignore the priors entirely.
        exploration = self.c_puct * np.sqrt(max(node.N, 1))
        return max(
            node.children.values(),
            key=lambda child: child.Q + exploration * child.P / (1 + child.N),
        )

    def _expand(self, node):
        """Evaluate ``node`` and, if it is not terminal, create its children.

        Returns:
            The value to back up for ``node``. ``run`` adds it unchanged to
            ``node.W`` and the parent maximises ``child.Q``, so it must be
            expressed from the point of view of the player who moved into
            ``node``.

        Raises:
            ValueError: If the model's policy vector is not ``POLICY_SIZE``
                long.
        """
        board = node.board

        if board.is_game_over():
            result = board.result()
            if result == '1-0':
                winner = chess.WHITE
            elif result == '0-1':
                winner = chess.BLACK
            else:
                return 0.0
            # board.turn is the side to move at this node; the player who
            # moved into it is the other side. Score the result for that
            # player so wins by either colour are rewarded alike.
            return 1.0 if winner != board.turn else -1.0

        tensor = torch.tensor(board_to_tensor(board)).unsqueeze(0)
        with torch.no_grad():
            policy, value = self.model(tensor)
        policy = torch.softmax(policy, dim=1).numpy()[0]
        if len(policy) != POLICY_SIZE:
            raise ValueError(
                f"Policy size mismatch: expected 4672, got {len(policy)}"
            )

        legal_moves = list(board.legal_moves)
        if legal_moves:
            # move_to_index always returns a value in [0, POLICY_SIZE), so
            # every legal move receives a prior and none is dropped.
            raw = np.array(
                [policy[self.move_to_index(m)] for m in legal_moves],
                dtype=np.float64,
            )
            mass = raw.sum()
            if mass > _MIN_PRIOR_MASS:
                priors = raw / mass
            else:
                # The network put essentially no mass on any legal move.
                priors = np.full(len(legal_moves), 1.0 / len(legal_moves))

            for move, prior in zip(legal_moves, priors):
                child_board = board.copy()
                child_board.push(move)
                child = MCTSNode(child_board, parent=node, move=move)
                child.P = float(prior)
                node.children[move] = child

        # NOTE(review): this value is backed up with the same perspective as
        # the terminal results above (player who moved into `node`). If the
        # model's value head is trained from the side-to-move perspective it
        # must be negated here -- confirm against the training code.
        return value.item()

    def move_to_index(self, move):
        """Map ``move`` to its slot in the flat policy vector.

        Non-promotion moves use ``from_square * 64 + to_square`` (0..4095).
        Promotion moves are placed from 4096 upwards and clamped to the last
        slot.

        Returns:
            An integer in ``[0, POLICY_SIZE)``.
        """
        from_square = move.from_square
        to_square = move.to_square
        promotion = move.promotion if move.promotion else 0

        if not promotion:
            return from_square * 64 + to_square

        # NOTE(review): python-chess encodes promotion pieces as KNIGHT=2 ..
        # QUEEN=5, so the piece offset below is at least 1024 and the result
        # always exceeds the vector length; every promotion therefore ends
        # up clamped to the final slot. Left unchanged because the trained
        # model presumably relies on this exact mapping -- confirm before
        # giving promotions distinct indices.
        idx = (
            4096
            + ((promotion - 1) * 64 * 64 // 4)
            + (from_square * 8 + to_square // 8)
        )
        return min(idx, POLICY_SIZE - 1)