Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files
- src/mcts.py +10 -1
- src/o2_agent.py +10 -2
src/mcts.py
CHANGED
@@ -20,7 +20,7 @@ class MCTS:
|
|
20 |
self.simulations = simulations
|
21 |
self.c_puct = c_puct
|
22 |
|
23 |
-
def run(self, board):
|
24 |
root = MCTSNode(board)
|
25 |
self._expand(root)
|
26 |
for _ in range(self.simulations):
|
@@ -45,6 +45,15 @@ class MCTS:
|
|
45 |
n.W += value
|
46 |
n.Q = n.W / n.N
|
47 |
value = -value # Switch perspective
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
# Choose move with highest visit count
|
49 |
best_move = max(root.children.items(), key=lambda item: item[1].N)[0]
|
50 |
return best_move
|
|
|
20 |
self.simulations = simulations
|
21 |
self.c_puct = c_puct
|
22 |
|
23 |
+
def run(self, board, temperature=0.0):
|
24 |
root = MCTSNode(board)
|
25 |
self._expand(root)
|
26 |
for _ in range(self.simulations):
|
|
|
45 |
n.W += value
|
46 |
n.Q = n.W / n.N
|
47 |
value = -value # Switch perspective
|
48 |
+
# Temperature-based sampling for opening diversity
|
49 |
+
if temperature and temperature > 0:
|
50 |
+
import numpy as np
|
51 |
+
moves = list(root.children.keys())
|
52 |
+
visits = np.array([root.children[m].N for m in moves], dtype=np.float32)
|
53 |
+
probs = visits ** (1.0 / temperature)
|
54 |
+
probs = probs / np.sum(probs)
|
55 |
+
move = np.random.choice(moves, p=probs)
|
56 |
+
return move
|
57 |
# Choose move with highest visit count
|
58 |
best_move = max(root.children.items(), key=lambda item: item[1].N)[0]
|
59 |
return best_move
|
src/o2_agent.py
CHANGED
@@ -25,10 +25,10 @@ class O2Agent:
|
|
25 |
self.model.load_state_dict(torch.load(model_path))
|
26 |
self.model.eval()
|
27 |
|
28 |
-
def select_move(self, board, use_mcts=True, simulations=100):
|
29 |
if use_mcts:
|
30 |
mcts = MCTS(self.model, simulations=simulations)
|
31 |
-
return mcts.run(board)
|
32 |
tensor = torch.tensor(board_to_tensor(board)).unsqueeze(0)
|
33 |
with torch.no_grad():
|
34 |
policy, _ = self.model(tensor)
|
@@ -37,6 +37,14 @@ class O2Agent:
|
|
37 |
for move in legal_moves:
|
38 |
move_idx = self.move_to_index(move)
|
39 |
move_scores.append(policy[0, move_idx].item())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
best_move = legal_moves[int(torch.tensor(move_scores).argmax())]
|
41 |
return best_move
|
42 |
|
|
|
25 |
self.model.load_state_dict(torch.load(model_path))
|
26 |
self.model.eval()
|
27 |
|
28 |
+
def select_move(self, board, use_mcts=True, simulations=100, temperature=0.0):
|
29 |
if use_mcts:
|
30 |
mcts = MCTS(self.model, simulations=simulations)
|
31 |
+
return mcts.run(board, temperature=temperature)
|
32 |
tensor = torch.tensor(board_to_tensor(board)).unsqueeze(0)
|
33 |
with torch.no_grad():
|
34 |
policy, _ = self.model(tensor)
|
|
|
37 |
for move in legal_moves:
|
38 |
move_idx = self.move_to_index(move)
|
39 |
move_scores.append(policy[0, move_idx].item())
|
40 |
+
if temperature and temperature > 0:
|
41 |
+
# Softmax sampling
|
42 |
+
import numpy as np
|
43 |
+
scores = np.array(move_scores)
|
44 |
+
exp_scores = np.exp(scores / temperature)
|
45 |
+
probs = exp_scores / np.sum(exp_scores)
|
46 |
+
move = np.random.choice(legal_moves, p=probs)
|
47 |
+
return move
|
48 |
best_move = legal_moves[int(torch.tensor(move_scores).argmax())]
|
49 |
return best_move
|
50 |
|