File size: 3,825 Bytes
8806ce1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""

Monte Carlo Tree Search (MCTS) for o1 agent.

Basic implementation: runs simulations, selects moves by visit count.

Integrate with neural net for policy/value guidance for full strength.

"""
import math
import random
from collections import defaultdict

import chess
import torch

class MCTSNode:
    def __init__(self, board, parent=None, move=None):
        self.board = board.copy()
        self.parent = parent
        self.move = move
        self.children = []
        self.visits = 0
        self.value = 0.0
        self.untried_moves = list(board.legal_moves)

    def is_fully_expanded(self):
        return len(self.untried_moves) == 0

    def best_child(self, c_param=1.4):
        choices = [
            (child.value / (child.visits + 1e-6) + c_param * ( (2 * (self.visits + 1e-6)) ** 0.5 / (child.visits + 1e-6) ), child)
            for child in self.children
        ]
        return max(choices, key=lambda x: x[0])[1]

class MCTS:
    """Monte Carlo Tree Search over python-chess boards.

    Leaves are scored by ``agent`` when one is provided, otherwise by a
    uniformly random playout to the end of the game. All scores are treated as
    White's perspective: +1 White wins, -1 Black wins, 0 draw.
    """

    def __init__(self, agent=None, simulations=50):
        """Create a searcher.

        Args:
            agent: optional evaluator exposing ``predict(board)`` (and
                optionally ``predict_with_diffusion``) that returns a
                ``(policy, value)`` pair whose ``value`` supports ``.item()``.
            simulations: number of select/expand/evaluate/backpropagate
                iterations run per ``search`` call.
        """
        self.agent = agent
        self.simulations = simulations

    def search(self, board, restrict_top_n=None):
        """Run ``self.simulations`` MCTS iterations from *board* and pick a move.

        Args:
            board: position to search from; it is not modified.
            restrict_top_n: if given, only the ``restrict_top_n`` most-visited
                root moves are candidates. The final pick is the most-visited
                move, which is always among them, so this never changes the
                result; it is kept for API compatibility.

        Returns:
            The most-visited root move. Ties go to the earliest-expanded
            child. If no root child was expanded (e.g. ``simulations == 0``),
            a uniformly random legal move is returned instead.

        Raises:
            ValueError: if ``restrict_top_n`` is given and is less than 1.
            IndexError: from ``random.choice`` when no root child was expanded
                and *board* has no legal moves.
        """
        if restrict_top_n is not None and restrict_top_n < 1:
            raise ValueError(f"restrict_top_n must be at least 1, got {restrict_top_n!r}")

        root = MCTSNode(board)
        for _ in range(self.simulations):
            node = root
            sim_board = board.copy()
            # Selection: descend through fully expanded nodes by UCT score.
            while node.is_fully_expanded() and node.children:
                node = node.best_child()
                sim_board.push(node.move)
            # Expansion: turn one untried move into a new child.
            if node.untried_moves:
                move = random.choice(node.untried_moves)
                node.untried_moves.remove(move)
                sim_board.push(move)
                child = MCTSNode(sim_board, parent=node, move=move)
                node.children.append(child)
                node = child
            # Evaluation (White's perspective), then backpropagation.
            result = self.simulate(sim_board)
            self._backpropagate(node, result)

        if not root.children:
            return random.choice(list(board.legal_moves))
        candidates = sorted(root.children, key=lambda c: c.visits, reverse=True)
        if restrict_top_n is not None:
            candidates = candidates[:restrict_top_n]
        return max(candidates, key=lambda c: c.visits).move

    @staticmethod
    def _backpropagate(node, result):
        """Add one visit and *result* to every node from *node* up to the root.

        *result* is scored from White's perspective. Each node is credited from
        the viewpoint of the player who made ``node.move``. When ``best_child``
        maximises ``value / visits`` at a parent, it therefore maximises for the
        side choosing at that parent, and the sign alternates ply by ply.
        """
        while node is not None:
            node.visits += 1
            # node.board.turn is the side to move *after* node.move, so the
            # mover is the other colour. python-chess colours are bools
            # (chess.WHITE is True, chess.BLACK is False).
            white_moved = not node.board.turn
            node.value += result if white_moved else -result
            node = node.parent

    @staticmethod
    def _score_result(result):
        """Map a python-chess result string to +1 / -1 / 0 from White's view."""
        return {'1-0': 1, '0-1': -1}.get(result, 0)

    def simulate(self, board, use_diffusion=True, diffusion_steps=10, noise_scale=1.0):
        """Estimate the value of *board* from White's perspective.

        If the game on *board* is already over, it is scored exactly from
        ``board.result()``. Otherwise the agent is queried when one is set.
        Without an agent, a uniformly random playout is run on a copy of
        *board*, so the argument is never modified.

        Args:
            board: position to evaluate.
            use_diffusion: prefer ``agent.predict_with_diffusion`` when the
                agent provides it.
            diffusion_steps: forwarded as ``steps`` to ``predict_with_diffusion``.
            noise_scale: forwarded to ``predict_with_diffusion``.

        Returns:
            +1 / -1 / 0 for a finished or played-out game, or the agent's
            scalar value (``value.item()``).
        """
        if self.agent is not None and not board.is_game_over():
            # NOTE(review): backpropagation treats this value as White's
            # perspective, like the playout score below. Confirm the agent's
            # value head uses that convention rather than side-to-move.
            with torch.no_grad():
                if use_diffusion and hasattr(self.agent, 'predict_with_diffusion'):
                    _, value = self.agent.predict_with_diffusion(
                        board, steps=diffusion_steps, noise_scale=noise_scale
                    )
                else:
                    _, value = self.agent.predict(board)
                return value.item()
        # Random playout (or exact scoring of an already-finished game).
        playout = board.copy()
        while not playout.is_game_over():
            playout.push(random.choice(list(playout.legal_moves)))
        return self._score_result(playout.result())