# play-with-o1/src/o1/mcts.py
"""
Monte Carlo Tree Search (MCTS) for o1 agent.
Basic implementation: runs simulations, selects moves by visit count.
Integrate with neural net for policy/value guidance for full strength.
"""
import math
import random

import chess
import torch


class MCTSNode:
    """A single node in the search tree, wrapping one board position."""

    def __init__(self, board, parent=None, move=None):
        self.board = board.copy()
        self.parent = parent
        self.move = move  # move that led from the parent to this node
        self.children = []
        self.visits = 0
        self.value = 0.0  # accumulated value, stored from the parent's perspective
        self.untried_moves = list(board.legal_moves)

    def is_fully_expanded(self):
        return len(self.untried_moves) == 0

    def best_child(self, c_param=1.4):
        # UCB1: average value (exploitation) plus a log-scaled visit bonus (exploration)
        choices = [
            (
                child.value / (child.visits + 1e-6)
                + c_param * math.sqrt(2 * math.log(self.visits + 1) / (child.visits + 1e-6)),
                child,
            )
            for child in self.children
        ]
        return max(choices, key=lambda x: x[0])[1]


class MCTS:
    def __init__(self, agent=None, simulations=50):
        self.agent = agent              # optional neural agent used for leaf evaluation
        self.simulations = simulations  # number of simulations per call to search()

    def search(self, board, restrict_top_n=None):
        root = MCTSNode(board)
        for _ in range(self.simulations):
            node = root
            sim_board = board.copy()
            # Selection: descend with UCB1 until a node still has untried moves
            while node.is_fully_expanded() and node.children:
                node = node.best_child()
                sim_board.push(node.move)
            # Expansion: add one child for a randomly chosen untried move
            if node.untried_moves:
                move = random.choice(node.untried_moves)
                sim_board.push(move)
                child = MCTSNode(sim_board, parent=node, move=move)
                node.children.append(child)
                node.untried_moves.remove(move)
                node = child
            # Simulation: evaluate the leaf; simulate() returns a value from the
            # perspective of the side to move at sim_board
            result = self.simulate(sim_board)
            # Backpropagation (negamax convention): each node stores the value from
            # its parent's perspective, so the sign flips at every level
            value = -result
            while node:
                node.visits += 1
                node.value += value
                value = -value
                node = node.parent
        # Choose the move with the most visits, restricted to the top-N if specified
        if not root.children:
            return random.choice(list(board.legal_moves))
        children_sorted = sorted(root.children, key=lambda c: c.visits, reverse=True)
        if restrict_top_n is not None and restrict_top_n < len(children_sorted):
            # Only consider the top-N most-visited moves
            children_sorted = children_sorted[:restrict_top_n]
        best = max(children_sorted, key=lambda c: c.visits)
        return best.move

    def simulate(self, board, use_diffusion=True, diffusion_steps=10, noise_scale=1.0):
        """Evaluate `board` and return a value from the perspective of the side to move."""
        # Use the neural network to evaluate the board instead of a random playout.
        # Assumption: the agent's value head scores the position for the side to move.
        if self.agent is not None:
            with torch.no_grad():
                if use_diffusion and hasattr(self.agent, 'predict_with_diffusion'):
                    _, value = self.agent.predict_with_diffusion(board, steps=diffusion_steps, noise_scale=noise_scale)
                else:
                    _, value = self.agent.predict(board)
            return value.item()
        # Fallback: play random moves until the game ends
        side_to_move = board.turn
        while not board.is_game_over():
            move = random.choice(list(board.legal_moves))
            board.push(move)
        result = board.result()
        if result == '1-0':
            score = 1.0
        elif result == '0-1':
            score = -1.0
        else:
            score = 0.0
        # board.result() is from White's point of view; convert to the side to move
        return score if side_to_move == chess.WHITE else -score
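

# Minimal usage sketch (illustrative, not part of the original module): with no
# agent supplied, search() falls back to random playouts, so this only needs the
# python-chess package. A real run would pass an o1 agent exposing predict() /
# predict_with_diffusion().
if __name__ == "__main__":
    board = chess.Board()
    mcts = MCTS(agent=None, simulations=25)  # small budget so the demo stays fast
    move = mcts.search(board, restrict_top_n=3)
    print("MCTS suggests:", board.san(move))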