FlameF0X committed
Commit 5342d4d · verified · 1 Parent(s): ce00937

Upload 2 files

Files changed (2)
  1. src/mcts.py +10 -1
  2. src/o2_agent.py +10 -2
src/mcts.py CHANGED
@@ -20,7 +20,7 @@ class MCTS:
         self.simulations = simulations
         self.c_puct = c_puct
 
-    def run(self, board):
+    def run(self, board, temperature=0.0):
         root = MCTSNode(board)
         self._expand(root)
         for _ in range(self.simulations):
@@ -45,6 +45,15 @@ class MCTS:
             n.W += value
             n.Q = n.W / n.N
             value = -value  # Switch perspective
+        # Temperature-based sampling for opening diversity
+        if temperature and temperature > 0:
+            import numpy as np
+            moves = list(root.children.keys())
+            visits = np.array([root.children[m].N for m in moves], dtype=np.float32)
+            probs = visits ** (1.0 / temperature)
+            probs = probs / np.sum(probs)
+            move = np.random.choice(moves, p=probs)
+            return move
         # Choose move with highest visit count
         best_move = max(root.children.items(), key=lambda item: item[1].N)[0]
         return best_move
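
The new branch in MCTS.run samples the root move in proportion to visit counts raised to 1/temperature instead of always playing the most-visited child. Below is a minimal, self-contained sketch of that sampling rule with made-up visit counts; the move names and numbers are illustrative only, and the MCTSNode tree itself is not reproduced:

import numpy as np

# Hypothetical root statistics after a search: three candidate moves
# with 60, 30 and 10 visits respectively (float64 keeps the
# sum-to-1 check inside np.random.choice comfortably satisfied).
moves = ["e2e4", "d2d4", "g1f3"]
visits = np.array([60.0, 30.0, 10.0])

def sample_root_move(moves, visits, temperature):
    # temperature <= 0 reproduces the greedy most-visited fallback in run().
    if not temperature or temperature <= 0:
        return moves[int(np.argmax(visits))]
    probs = visits ** (1.0 / temperature)   # sharpen or flatten the visit distribution
    probs = probs / np.sum(probs)           # normalize to a probability vector
    return np.random.choice(moves, p=probs)

print(sample_root_move(moves, visits, temperature=1.0))  # roughly a 60/30/10% split
print(sample_root_move(moves, visits, temperature=0.1))  # almost always "e2e4"
print(sample_root_move(moves, visits, temperature=0.0))  # greedy: always "e2e4"

At temperature=1.0 the agent samples in direct proportion to visit counts, the usual setting for the first few self-play moves; lowering the temperature converges back to the deterministic, most-visited choice.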
src/o2_agent.py CHANGED
@@ -25,10 +25,10 @@ class O2Agent:
         self.model.load_state_dict(torch.load(model_path))
         self.model.eval()
 
-    def select_move(self, board, use_mcts=True, simulations=100):
+    def select_move(self, board, use_mcts=True, simulations=100, temperature=0.0):
         if use_mcts:
             mcts = MCTS(self.model, simulations=simulations)
-            return mcts.run(board)
+            return mcts.run(board, temperature=temperature)
         tensor = torch.tensor(board_to_tensor(board)).unsqueeze(0)
         with torch.no_grad():
             policy, _ = self.model(tensor)
@@ -37,6 +37,14 @@ class O2Agent:
         for move in legal_moves:
             move_idx = self.move_to_index(move)
             move_scores.append(policy[0, move_idx].item())
+        if temperature and temperature > 0:
+            # Softmax sampling
+            import numpy as np
+            scores = np.array(move_scores)
+            exp_scores = np.exp(scores / temperature)
+            probs = exp_scores / np.sum(exp_scores)
+            move = np.random.choice(legal_moves, p=probs)
+            return move
         best_move = legal_moves[int(torch.tensor(move_scores).argmax())]
         return best_move
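
When MCTS is disabled, select_move now draws the move from a softmax over the legal-move scores divided by temperature instead of taking the argmax. The sketch below isolates that sampling step with placeholder scores; legal_moves, the score values, and the pick_move helper are illustrative assumptions, not the repository's real objects:

import numpy as np

legal_moves = ["e2e4", "d2d4", "g1f3", "c2c4"]
move_scores = [2.1, 1.7, 1.5, 0.4]  # hypothetical policy-head scores for the legal moves

def pick_move(move_scores, legal_moves, temperature):
    # temperature <= 0 keeps the existing deterministic argmax behaviour.
    if not temperature or temperature <= 0:
        return legal_moves[int(np.argmax(move_scores))]
    scores = np.array(move_scores, dtype=np.float64)
    exp_scores = np.exp(scores / temperature)   # temperature-scaled softmax numerator
    probs = exp_scores / np.sum(exp_scores)     # normalize to probabilities
    return np.random.choice(legal_moves, p=probs)

print(pick_move(move_scores, legal_moves, temperature=0.0))  # always "e2e4"
print(pick_move(move_scores, legal_moves, temperature=1.0))  # biased toward higher-scored moves

If the policy head can produce large raw scores, subtracting scores.max() before the exponential is a common guard against overflow; the resulting sampling distribution is unchanged by that shift.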