Spaces:
Build error
Build error
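"""
Monte Carlo Tree Search (MCTS) for the o1 agent.

Basic implementation: runs a fixed number of simulations and selects the move
with the most visits. Leaf positions are scored by the agent's value output when
an agent is supplied, otherwise by a random playout. Integrating the network's
policy output as move priors would further strengthen the search.
"""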
import math
import random
from collections import defaultdict

import chess
import torch
class MCTSNode:
    """One position in the search tree together with its search statistics."""

    def __init__(self, board, parent=None, move=None):
        """Create a node for ``board``.

        Args:
            board: Position at this node. A private copy is stored, so later
                pushes on the caller's board do not affect the node.
            parent: Parent node, or None for the root.
            move: Move that led from ``parent`` to this position (None at the root).
        """
        self.board = board.copy()
        self.parent = parent
        self.move = move
        self.children = []
        self.visits = 0
        self.value = 0.0  # sum of results credited to this node during backpropagation
        # Legal moves that have not yet been expanded into child nodes.
        self.untried_moves = list(board.legal_moves)

    def is_fully_expanded(self):
        """Return True once every legal move has been expanded into a child."""
        return not self.untried_moves

    def best_child(self, c_param=1.4):
        """Return the child with the highest UCT (UCB1) score.

        score = mean_value + c_param * sqrt(2 * ln(N_parent) / N_child)

        An unvisited child has an unbounded score, so the first such child is
        returned immediately. Among equal scores the earliest child wins.
        Assumes ``child.value`` is expressed from the perspective of the player
        choosing at this node, so that maximizing it is correct for that player.

        Raises:
            ValueError: if this node has no children (from ``max``).
        """
        unvisited = next((c for c in self.children if c.visits == 0), None)
        if unvisited is not None:
            return unvisited

        # Clamp so an inconsistent zero parent count cannot hit log(0).
        log_parent_visits = math.log(max(self.visits, 1))

        def uct_score(child):
            exploitation = child.value / child.visits
            exploration = c_param * math.sqrt(2 * log_parent_visits / child.visits)
            return exploitation + exploration

        return max(self.children, key=uct_score)
class MCTS:
    """Monte Carlo Tree Search over chess positions.

    Each simulation selects a path with ``MCTSNode.best_child``, expands one
    untried move, scores the resulting position with ``simulate`` and
    backpropagates that score up the path. ``search`` then returns the root
    move with the most visits.
    """

    def __init__(self, agent=None, simulations=50):
        """Create a searcher.

        Args:
            agent: Optional evaluator. If given, ``simulate`` calls its
                ``predict_with_diffusion`` method when available, otherwise its
                ``predict`` method, instead of playing a random game.
            simulations: Number of simulations run by each ``search`` call.
        """
        self.agent = agent
        self.simulations = simulations

    def search(self, board, restrict_top_n=None):
        """Run ``self.simulations`` simulations from ``board`` and choose a move.

        Args:
            board: Position to search from. It is copied, never mutated.
            restrict_top_n: Accepted for backward compatibility. The returned
                move is the most-visited root child, which is always inside any
                top-N-by-visits subset, so this value cannot change the result.
                Non-positive values are ignored instead of raising.

        Returns:
            The most-visited root move (ties go to the earliest-expanded child).
            If no simulation expanded the root (e.g. ``simulations == 0``), a
            random legal move.

        Raises:
            IndexError: if ``board`` has no legal moves (from ``random.choice``).
        """
        root = MCTSNode(board)
        for _ in range(self.simulations):
            node = root
            sim_board = board.copy()

            # Selection: follow UCT while every move at the node has been tried.
            while node.is_fully_expanded() and node.children:
                node = node.best_child()
                sim_board.push(node.move)

            # Expansion: add one random untried move as a new child.
            if node.untried_moves:
                move = random.choice(node.untried_moves)
                sim_board.push(move)
                child = MCTSNode(sim_board, parent=node, move=move)
                node.children.append(child)
                node.untried_moves.remove(move)
                node = child

            # Simulation: score the leaf (treated as White's perspective).
            result = self.simulate(sim_board)

            # Backpropagation: credit each node for the player who moved into it.
            self._backpropagate(node, result)

        if not root.children:
            return random.choice(list(board.legal_moves))
        best = max(root.children, key=lambda c: c.visits)
        return best.move

    @staticmethod
    def _backpropagate(node, result):
        """Add a visit and ``result`` to ``node`` and every ancestor.

        ``result`` is read as White's perspective. Each node's ``value`` is
        stored from the perspective of the player who made ``node.move``, so a
        parent maximizing mean child value favours moves good for the side
        choosing them at that parent.
        """
        while node is not None:
            node.visits += 1
            # Black to move at this node means White just played node.move.
            moved_by_white = node.board.turn == chess.BLACK
            node.value += result if moved_by_white else -result
            node = node.parent

    def simulate(self, board, use_diffusion=True, diffusion_steps=10, noise_scale=1.0):
        """Score ``board`` with the agent, or by a random playout.

        With an agent, returns ``value.item()`` from the agent's prediction,
        using ``predict_with_diffusion`` when ``use_diffusion`` is true and the
        agent defines it, and ``predict`` otherwise. Gradients are disabled.
        NOTE(review): ``search`` treats this value as White's perspective, like
        the playout result below; confirm the agent uses the same convention.

        Without an agent, plays uniformly random moves until the game is over
        and returns 1 for '1-0', -1 for '0-1' and 0 otherwise (draw). The moves
        are pushed onto ``board`` itself, so pass a copy if the position is
        still needed afterwards.
        """
        if self.agent is not None:
            with torch.no_grad():
                if use_diffusion and hasattr(self.agent, 'predict_with_diffusion'):
                    _, value = self.agent.predict_with_diffusion(
                        board, steps=diffusion_steps, noise_scale=noise_scale
                    )
                else:
                    _, value = self.agent.predict(board)
            return value.item()

        while not board.is_game_over():
            board.push(random.choice(list(board.legal_moves)))
        outcome = board.result()
        if outcome == '1-0':
            return 1
        if outcome == '0-1':
            return -1
        return 0