Spaces:
Sleeping
Sleeping
Upload 4 files
Browse files- src/app.py +277 -0
- src/mcts.py +40 -17
- src/o2_agent.py +86 -88
src/app.py
ADDED
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
st.set_page_config(page_title="Play Chess vs o2", layout="centered")
|
3 |
+
|
4 |
+
import chess
|
5 |
+
import chess.svg
|
6 |
+
import torch
|
7 |
+
from o2_model import O2Net, board_to_tensor
|
8 |
+
from o2_agent import O2Agent
|
9 |
+
from PIL import Image
|
10 |
+
import io
|
11 |
+
import base64
|
12 |
+
import os
|
13 |
+
import chess.pgn
|
14 |
+
import requests
|
15 |
+
import random
|
16 |
+
import re
|
17 |
+
import tempfile
|
18 |
+
|
19 |
+
# Use temp directory for cache/model paths
|
20 |
+
MODEL_CACHE_DIR = tempfile.gettempdir()
|
21 |
+
MODEL_REPO = "FlameF0X/o2"
|
22 |
+
MODEL_FILENAME = "o2_agent.pth"
|
23 |
+
|
24 |
+
def ensure_model():
    """Download the o2 checkpoint and return the local file path.

    The file ``MODEL_FILENAME`` from the Hub repository ``MODEL_REPO`` is
    resolved through ``hf_hub_download`` with ``MODEL_CACHE_DIR`` as its
    cache directory; whatever that call returns is handed back unchanged.
    """
    # Function-scope import: huggingface_hub is only loaded when this runs.
    from huggingface_hub import hf_hub_download

    download_kwargs = {
        "repo_id": MODEL_REPO,
        "filename": MODEL_FILENAME,
        "cache_dir": MODEL_CACHE_DIR,
    }
    local_path = hf_hub_download(**download_kwargs)
    return local_path
|
28 |
+
|
29 |
+
# Load model with better error handling
|
30 |
+
@st.cache_resource
def _load_agent_cached():
    """Build an ``O2Agent`` with downloaded weights; raise on any failure.

    Kept separate from ``load_agent`` so that ``st.cache_resource`` only ever
    stores a successfully built agent. A raised exception is not cached,
    whereas a returned ``None`` would be reused for every later call.
    """
    model_path = ensure_model()
    if model_path is None:
        raise FileNotFoundError("o2 checkpoint path could not be resolved")

    agent = O2Agent()
    # NOTE(review): torch.load unpickles a downloaded file. Consider passing
    # weights_only=True if the installed torch version supports it.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    agent.model.load_state_dict(state_dict)
    agent.model.eval()
    return agent


def load_agent():
    """Return the o2 agent, or ``None`` when it cannot be loaded.

    Successful loads are cached by ``_load_agent_cached``. Failures are
    reported with ``st.error`` and are NOT cached, so the next script run
    tries the download/load again instead of being stuck with ``None``.
    """
    try:
        return _load_agent_cached()
    except Exception as e:
        st.error(f"Error loading agent: {e}")
        return None
|
44 |
+
|
45 |
+
def render_svg(svg):
    """Embed SVG markup in an ``<img>`` tag through a base64 data URI.

    Returns a ``(html, b64)`` tuple: the ``<img>`` element string and the
    base64 text of the UTF-8 encoded SVG.
    """
    raw_bytes = svg.encode('utf-8')
    encoded = base64.b64encode(raw_bytes).decode('utf-8')
    img_tag = f"<img src='data:image/svg+xml;base64,{encoded}'/>"
    return img_tag, encoded
|
48 |
+
|
49 |
+
# --- Move parsing utility ---
|
50 |
+
def _try_san(board, text):
    """Return the legal move that ``text`` denotes in SAN on ``board``, else ``None``."""
    try:
        move = board.parse_san(text)
    except ValueError:
        # python-chess reports invalid, illegal and ambiguous SAN through
        # ValueError subclasses; any of them just means "not this spelling".
        return None
    return move if move in board.legal_moves else None


def parse_move_input(move_input, board):
    """Turn free-form user text into a legal move on ``board``.

    Interpretations are tried in this order and the first legal hit wins:

    1. UCI (e.g. ``e2e4``), lower-cased, when the text has at least four
       characters and characters 3-4 are alphanumeric.
    2. SAN exactly as typed, then upper-cased, lower-cased and capitalized.
    3. Castling aliases (``0-0``, ``oo``, ``o-o`` and the long forms).

    Args:
        move_input: Text typed by the user; may be ``None`` or empty.
        board: Position the move must be legal in.

    Returns:
        The matching legal move, or ``None`` if nothing matches.
    """
    if not move_input:
        return None
    move_input = move_input.strip()
    if not move_input:
        # Whitespace-only input can never name a move.
        return None

    if len(move_input) >= 4 and move_input[2:4].isalnum():
        try:
            uci_move = chess.Move.from_uci(move_input.lower())
        except ValueError:
            uci_move = None
        if uci_move is not None and uci_move in board.legal_moves:
            return uci_move

    # dict.fromkeys drops duplicate spellings while keeping the try order.
    san_candidates = dict.fromkeys(
        (
            move_input,
            move_input.upper(),
            move_input.lower(),
            move_input.capitalize(),
        )
    )
    for candidate in san_candidates:
        move = _try_san(board, candidate)
        if move is not None:
            return move

    castling_variations = {
        '0-0': 'O-O', '0-0-0': 'O-O-O', 'oo': 'O-O', 'ooo': 'O-O-O', 'o-o': 'O-O', 'o-o-o': 'O-O-O',
    }
    castling_san = castling_variations.get(move_input.lower())
    if castling_san is not None:
        return _try_san(board, castling_san)
    return None
|
97 |
+
|
98 |
+
# --- Main UI ---
|
99 |
+
# Seed the per-session game state the first time the script runs.
for state_key, make_default in (("board", chess.Board), ("history", list)):
    if state_key not in st.session_state:
        st.session_state[state_key] = make_default()

# Load the agent; any failure leaves the app in random-move mode.
agent = None
agent_loaded = False

with st.spinner("Loading o2 model..."):
    try:
        agent = load_agent()
        agent_loaded = agent is not None
        if agent_loaded:
            st.success("o2 model loaded successfully!")
        else:
            st.warning("Failed to load o2 model. Using random moves for AI.")
    except Exception as e:
        st.error(f"Failed to load o2: {e}")
        st.warning("Using random moves for AI.")

# Short aliases for the session objects used by the rest of the script.
board = st.session_state.board
history = st.session_state.history

st.title("♟️ Play Chess vs o2")

if not agent_loaded:
    st.info("🎲 AI is using random moves (o2 model not available)")

reset_clicked = st.button("Reset Game")
if reset_clicked:
    # Replace the stored state, then restart the script so the aliases refresh.
    st.session_state.board = chess.Board()
    st.session_state.history = []
    st.rerun()
|
132 |
+
|
133 |
+
# Create two columns for layout
|
134 |
+
col_board, col_pgn = st.columns([2, 1])

with col_board:
    # Placeholder so the board can be redrawn in place after a move.
    board_placeholder = st.empty()

    def render_board():
        """Draw the current position as inline SVG, highlighting the last move."""
        try:
            if board.move_stack:
                last_move = board.peek()
            else:
                last_move = None
            svg_board = chess.svg.board(board=board, lastmove=last_move, size=400)
            wrapper = f'<div style="display: flex; justify-content: center;">{svg_board}</div>'
            board_placeholder.markdown(wrapper, unsafe_allow_html=True)
        except Exception as e:
            st.error(f"Error rendering board: {e}")

    render_board()

    # Status line under the board: side to move and a check indicator.
    col1, col2 = st.columns(2)
    with col1:
        side_to_move = 'White' if board.turn == chess.WHITE else 'Black'
        st.write(f"**Turn:** {side_to_move}")
    with col2:
        if board.is_check():
            st.write("**Check!**")
|
155 |
+
|
156 |
+
with col_pgn:
    st.write("### Game History")
    # Placeholder so the move list can be refreshed in place.
    pgn_placeholder = st.empty()

    def render_pgn():
        """Show the game so far as PGN, or as plain numbered UCI pairs on error."""
        if not history:
            pgn_placeholder.text("No moves yet")
            return

        try:
            game = chess.pgn.Game()
            header_values = (
                ("Event", "Human vs o2"),
                ("White", "Human"),
                ("Black", "o2" if agent_loaded else "Random AI"),
            )
            for tag, value in header_values:
                game.headers[tag] = value

            # Replay the stored UCI moves on a scratch board and stop at the
            # first one that is not legal there.
            node = game
            temp_board = chess.Board()
            for uci in history:
                candidate = chess.Move.from_uci(uci)
                if candidate not in temp_board.legal_moves:
                    break
                node = node.add_main_variation(candidate)
                temp_board.push(candidate)
            pgn_placeholder.code(str(game), language="pgn")
        except Exception as e:
            # Fallback: raw UCI strings grouped as "N. white black".
            move_pairs = [
                f"{start // 2 + 1}. {history[start]} {history[start + 1] if start + 1 < len(history) else ''}"
                for start in range(0, len(history), 2)
            ]
            pgn_placeholder.code("\n".join(move_pairs))

    render_pgn()
|
188 |
+
|
189 |
+
# Human (White) turn: list the legal moves, then accept a typed or random move.
if not board.is_game_over() and board.turn == chess.WHITE:
    st.write("### Your Turn (White)")
    legal_moves = list(board.legal_moves)
    legal_moves_uci = [move.uci() for move in legal_moves]
    legal_moves_san = []
    for move in legal_moves:
        try:
            legal_moves_san.append(board.san(move))
        except (AssertionError, ValueError):
            # Fall back to UCI text if SAN cannot be produced for this move.
            legal_moves_san.append(move.uci())
    with st.expander("Show legal moves"):
        col1, col2 = st.columns(2)
        with col1:
            st.write("**Algebraic notation:**")
            st.write(", ".join(sorted(legal_moves_san)))
        with col2:
            st.write("**UCI notation:**")
            st.write(", ".join(sorted(legal_moves_uci)))
    user_move = st.text_input("Enter your move (e.g., E4, Nf3, e2e4, O-O):", key="move_input", help="You can use algebraic notation (E4, Nf3) or UCI notation (e2e4). Case doesn't matter!")
    col1, col2 = st.columns(2)
    with col1:
        if st.button("Submit Move"):
            if user_move:
                parsed_move = parse_move_input(user_move, board)
                if parsed_move:
                    try:
                        # SAN is defined relative to the position BEFORE the
                        # move, so compute it before pushing. Asking for it
                        # afterwards fails because the move is no longer legal.
                        played_san = board.san(parsed_move)
                        board.push(parsed_move)
                        history.append(parsed_move.uci())
                        render_board()  # Update board immediately
                        render_pgn()  # Update PGN immediately
                        st.success(f"You played: {played_san} ({parsed_move.uci()})")
                        st.rerun()
                    except Exception as e:
                        st.error(f"Error making move: {e}")
                else:
                    st.warning(f"Invalid move: '{user_move}'. Please check the legal moves above.")
            else:
                st.warning("Please enter a move.")
    with col2:
        if st.button("Random Move"):
            if legal_moves:
                random_move = random.choice(legal_moves)
                board.push(random_move)
                history.append(random_move.uci())
                render_board()  # Update board immediately
                render_pgn()  # Update PGN immediately
                st.rerun()
|
237 |
+
|
238 |
+
# Engine (Black) turn: ask the agent for a move, or pick randomly without it.
if board.turn == chess.BLACK and not board.is_game_over():
    st.write("### o2's Turn (Black)")
    with st.spinner("o2 is thinking..."):
        try:
            if agent_loaded and agent:
                # Sample with temperature while fewer than 20 plies have been
                # played (10 moves per side); after that pick greedily.
                sampling_temperature = 1.2 if len(history) < 20 else 0.0
                move = agent.select_move(
                    board,
                    use_mcts=True,
                    simulations=30,
                    temperature=sampling_temperature,
                )
            else:
                legal_moves = list(board.legal_moves)
                if legal_moves:
                    move = random.choice(legal_moves)
                else:
                    move = None

            if not (move and move in board.legal_moves):
                st.error("o2 couldn't find a valid move")
            else:
                # SAN is taken before the push, while the move is still legal.
                move_san = board.san(move)
                board.push(move)
                history.append(move.uci())
                st.success(f"o2 played: {move_san} ({move.uci()})")
                st.rerun()
        except Exception as e:
            st.error(f"Error during o2 move: {e}")
|
261 |
+
|
262 |
+
# Game-over panel: announce the winner, show details, offer a restart.
if board.is_game_over():
    st.write("### Game Over!")
    result = board.result()
    outcome = board.outcome()

    # Map each decisive result string to the widget and text announcing it;
    # anything else is reported as a draw.
    banners = {
        "1-0": (st.success, "White wins!"),
        "0-1": (st.error, "Black wins!"),
    }
    announce, headline = banners.get(result, (st.info, "Draw!"))
    announce(headline)

    st.write(f"**Result:** {result}")
    st.write(f"**Termination:** {outcome.termination.name}")

    new_game_clicked = st.button("Start New Game")
    if new_game_clicked:
        st.session_state.board = chess.Board()
        st.session_state.history = []
        st.rerun()
|
src/mcts.py
CHANGED
@@ -43,11 +43,10 @@ class MCTS:
|
|
43 |
for n in reversed(search_path):
|
44 |
n.N += 1
|
45 |
n.W += value
|
46 |
-
n.Q = n.W / n.N
|
47 |
value = -value # Switch perspective
|
48 |
# Temperature-based sampling for opening diversity
|
49 |
if temperature and temperature > 0:
|
50 |
-
import numpy as np
|
51 |
moves = list(root.children.keys())
|
52 |
visits = np.array([root.children[m].N for m in moves], dtype=np.float32)
|
53 |
probs = visits ** (1.0 / temperature)
|
@@ -71,27 +70,51 @@ class MCTS:
|
|
71 |
with torch.no_grad():
|
72 |
policy, value = self.model(tensor)
|
73 |
policy = torch.softmax(policy, dim=1).numpy()[0]
|
|
|
|
|
74 |
legal_moves = list(node.board.legal_moves)
|
75 |
-
total_p = 1e-8
|
76 |
for move in legal_moves:
|
77 |
-
|
78 |
-
|
79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
for move in legal_moves:
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
|
|
|
|
|
|
|
|
|
|
88 |
return value.item()
|
89 |
-
|
90 |
def move_to_index(self, move):
|
91 |
from_square = move.from_square
|
92 |
to_square = move.to_square
|
93 |
promotion = move.promotion if move.promotion else 0
|
94 |
-
|
|
|
|
|
95 |
if promotion:
|
96 |
-
|
97 |
-
|
|
|
|
|
|
43 |
for n in reversed(search_path):
|
44 |
n.N += 1
|
45 |
n.W += value
|
46 |
+
n.Q = n.W / n.N if n.N > 0 else 0.0
|
47 |
value = -value # Switch perspective
|
48 |
# Temperature-based sampling for opening diversity
|
49 |
if temperature and temperature > 0:
|
|
|
50 |
moves = list(root.children.keys())
|
51 |
visits = np.array([root.children[m].N for m in moves], dtype=np.float32)
|
52 |
probs = visits ** (1.0 / temperature)
|
|
|
70 |
with torch.no_grad():
|
71 |
policy, value = self.model(tensor)
|
72 |
policy = torch.softmax(policy, dim=1).numpy()[0]
|
73 |
+
assert len(policy) == 4672, f"Policy size mismatch: expected 4672, got {len(policy)}"
|
74 |
+
|
75 |
legal_moves = list(node.board.legal_moves)
|
76 |
+
total_p = 1e-8 # Small epsilon to prevent division by zero
|
77 |
for move in legal_moves:
|
78 |
+
try:
|
79 |
+
idx = self.move_to_index(move)
|
80 |
+
if 0 <= idx < 4672: # Ensure index is within bounds
|
81 |
+
p = policy[idx]
|
82 |
+
total_p += p
|
83 |
+
except Exception:
|
84 |
+
continue # Skip moves that can't be indexed properly
|
85 |
+
|
86 |
+
if total_p < 1e-8: # If all probabilities are extremely small
|
87 |
+
total_p = 1.0 # Fall back to uniform distribution
|
88 |
+
# Use uniform distribution only for legal moves
|
89 |
+
for move in legal_moves:
|
90 |
+
idx = self.move_to_index(move)
|
91 |
+
if 0 <= idx < 4672:
|
92 |
+
policy[idx] = 1.0 / len(legal_moves)
|
93 |
+
|
94 |
+
# Create child nodes only for valid moves
|
95 |
for move in legal_moves:
|
96 |
+
try:
|
97 |
+
idx = self.move_to_index(move)
|
98 |
+
if 0 <= idx < 4672:
|
99 |
+
p = policy[idx] / total_p
|
100 |
+
child_board = node.board.copy()
|
101 |
+
child_board.push(move)
|
102 |
+
child = MCTSNode(child_board, parent=node, move=move)
|
103 |
+
child.P = p
|
104 |
+
node.children[move] = child
|
105 |
+
except Exception:
|
106 |
+
continue # Skip problematic moves
|
107 |
+
|
108 |
return value.item()
|
|
|
109 |
def move_to_index(self, move):
    """Map a move to its slot in the policy vector.

    Non-promotion moves use ``from_square * 64 + to_square``. Promotion
    moves use a separate formula that starts at 4096. The result is clamped
    so it never exceeds 4671.
    """
    origin, target = move.from_square, move.to_square
    promo = move.promotion or 0

    if promo:
        # NOTE(review): for any promotion value of 2 or more the piece offset
        # alone is at least 1024, so the sum exceeds 4671 and the clamp below
        # returns 4671. If move.promotion carries python-chess piece types
        # (knight=2 ... queen=5 -- confirm), every promotion shares that slot.
        piece_offset = (promo - 1) * 64 * 64 // 4
        square_offset = origin * 8 + target // 8
        slot = 4096 + piece_offset + square_offset
    else:
        slot = origin * 64 + target

    # Keep the index inside the policy vector.
    return min(slot, 4671)
|
src/o2_agent.py
CHANGED
@@ -1,88 +1,86 @@
|
|
1 |
-
import numpy as np
|
2 |
-
import chess
|
3 |
-
import torch
|
4 |
-
from o2_model import O2Net, board_to_tensor
|
5 |
-
from mcts import MCTS
|
6 |
-
import random
|
7 |
-
|
8 |
-
# Optional: Endgame tablebase and opening book integration placeholders
|
9 |
-
# You can use python-chess's tablebase and opening book modules if desired
|
10 |
-
# Example for endgame tablebase:
|
11 |
-
# from chess import tablebase
|
12 |
-
# tb = tablebase.Tablebase()
|
13 |
-
# tb.add_tablebase('/path/to/syzygy')
|
14 |
-
# if tb.probe_wdl(board) is not None:
|
15 |
-
# # Use tablebase move
|
16 |
-
# Example for opening book:
|
17 |
-
# from chess.polyglot import open_reader
|
18 |
-
# with open_reader('book.bin') as reader:
|
19 |
-
# entry = reader.find(board)
|
20 |
-
# move = entry.move
|
21 |
-
|
22 |
-
class O2Agent:
|
23 |
-
def __init__(self, model_path=None):
|
24 |
-
self.model = O2Net()
|
25 |
-
if model_path:
|
26 |
-
self.model.load_state_dict(torch.load(model_path))
|
27 |
-
self.model.eval()
|
28 |
-
|
29 |
-
def select_move(self, board, use_mcts=True, simulations=100, temperature=0.0):
|
30 |
-
if use_mcts:
|
31 |
-
mcts = MCTS(self.model, simulations=simulations)
|
32 |
-
return mcts.run(board, temperature=temperature)
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
idx
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
move
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
move = agent.select_move(board)
|
88 |
-
print("O2 selects:", move)
|
|
|
1 |
+
import numpy as np
|
2 |
+
import chess
|
3 |
+
import torch
|
4 |
+
from o2_model import O2Net, board_to_tensor
|
5 |
+
from mcts import MCTS
|
6 |
+
import random
|
7 |
+
|
8 |
+
# Optional: Endgame tablebase and opening book integration placeholders
|
9 |
+
# You can use python-chess's tablebase and opening book modules if desired
|
10 |
+
# Example for endgame tablebase:
|
11 |
+
# from chess import tablebase
|
12 |
+
# tb = tablebase.Tablebase()
|
13 |
+
# tb.add_tablebase('/path/to/syzygy')
|
14 |
+
# if tb.probe_wdl(board) is not None:
|
15 |
+
# # Use tablebase move
|
16 |
+
# Example for opening book:
|
17 |
+
# from chess.polyglot import open_reader
|
18 |
+
# with open_reader('book.bin') as reader:
|
19 |
+
# entry = reader.find(board)
|
20 |
+
# move = entry.move
|
21 |
+
|
22 |
+
class O2Agent:
    """Chess agent that picks moves with an ``O2Net`` model, optionally via MCTS."""

    def __init__(self, model_path=None):
        """Create the network and, when ``model_path`` is given, load its weights."""
        self.model = O2Net()
        if model_path:
            # map_location keeps checkpoints that were saved on a GPU loadable
            # on CPU-only hosts.
            state_dict = torch.load(model_path, map_location="cpu")
            self.model.load_state_dict(state_dict)
        self.model.eval()

    def select_move(self, board, use_mcts=True, simulations=100, temperature=0.0):
        """Choose a move for the side to move on ``board``.

        Args:
            board: Position to move in.
            use_mcts: When true, delegate to ``MCTS.run``; otherwise score the
                legal moves directly from the policy head.
            simulations: Simulation count passed to ``MCTS``.
            temperature: ``> 0`` samples from a softmax over the scores (or is
                forwarded to MCTS); ``0``/``None`` picks the best score.

        Returns:
            The selected move. On the direct-policy path, ``None`` when the
            position has no legal moves.
        """
        if use_mcts:
            mcts = MCTS(self.model, simulations=simulations)
            return mcts.run(board, temperature=temperature)

        legal_moves = list(board.legal_moves)
        if not legal_moves:
            # Nothing to score; avoids argmax/sampling over an empty list.
            return None

        tensor = torch.tensor(board_to_tensor(board)).unsqueeze(0)
        with torch.no_grad():
            policy, _ = self.model(tensor)
        move_scores = [
            policy[0, self.move_to_index(move)].item() for move in legal_moves
        ]

        if temperature and temperature > 0:
            # Softmax sampling. Subtracting the maximum leaves the
            # distribution unchanged but prevents overflow in exp().
            scaled = np.array(move_scores, dtype=np.float64) / temperature
            scaled -= scaled.max()
            exp_scores = np.exp(scaled)
            probs = exp_scores / exp_scores.sum()
            # Sample a position in the list rather than the move objects.
            chosen = np.random.choice(len(legal_moves), p=probs)
            return legal_moves[int(chosen)]

        return legal_moves[int(torch.tensor(move_scores).argmax())]

    def move_to_index(self, move):
        """Encode ``move`` as ``from_square * 64 + to_square + promotion_offset``.

        Indices of 4672 or more are wrapped with ``% 4672``.
        """
        from_square = move.from_square
        to_square = move.to_square
        promotion = move.promotion if move.promotion else 0
        promotion_offset = 0
        if promotion:
            # python-chess piece types: knight=2, bishop=3, rook=4, queen=5.
            promotion_offset = 4096 + (promotion - 1) * 256
        idx = from_square * 64 + to_square + promotion_offset
        # NOTE(review): the modulo wrap sends out-of-range indices onto slots
        # that other moves also use. Kept as-is because the mapping has to
        # match whatever the network was trained with -- confirm.
        return idx if idx < 4672 else idx % 4672

    def index_to_move(self, board, index):
        """Decode ``index`` into a legal move on ``board``.

        Falls back to a random legal move when the decoded move is not legal,
        and returns ``None`` when the position has no legal moves at all.
        """
        # NOTE(review): this is not an exact inverse of move_to_index. The
        # promotion piece is derived as (index - 4096) % 4 + 1, while encoding
        # offsets by (promotion - 1) * 256 and may wrap -- confirm the
        # intended layout before relying on round-trips.
        if index >= 4096:
            promotion = (index - 4096) % 4 + 1
            idx = index - 4096
            from_square = idx // 64
            to_square = idx % 64
            move = chess.Move(from_square, to_square, promotion=promotion)
        else:
            from_square = index // 64
            to_square = index % 64
            move = chess.Move(from_square, to_square)
        if move in board.legal_moves:
            return move
        # Fallback: pick a random legal move, if there is one.
        legal_moves = list(board.legal_moves)
        return random.choice(legal_moves) if legal_moves else None
|
81 |
+
|
82 |
+
if __name__ == "__main__":
    # Smoke run: build an agent without loading weights and ask it for a
    # move from the standard starting position.
    start_position = chess.Board()
    demo_agent = O2Agent()
    chosen = demo_agent.select_move(start_position)
    print("O2 selects:", chosen)
|
|
|
|