FlameF0X commited on
Commit
bcc23fe
·
verified ·
1 Parent(s): 6112849

Upload 4 files

Browse files
Files changed (3) hide show
  1. src/app.py +277 -0
  2. src/mcts.py +40 -17
  3. src/o2_agent.py +86 -88
src/app.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ st.set_page_config(page_title="Play Chess vs o2", layout="centered")
3
+
4
+ import chess
5
+ import chess.svg
6
+ import torch
7
+ from o2_model import O2Net, board_to_tensor
8
+ from o2_agent import O2Agent
9
+ from PIL import Image
10
+ import io
11
+ import base64
12
+ import os
13
+ import chess.pgn
14
+ import requests
15
+ import random
16
+ import re
17
+ import tempfile
18
+
19
+ # Use temp directory for cache/model paths
20
+ MODEL_CACHE_DIR = tempfile.gettempdir()
21
+ MODEL_REPO = "FlameF0X/o2"
22
+ MODEL_FILENAME = "o2_agent.pth"
23
+
24
+ def ensure_model():
25
+ from huggingface_hub import hf_hub_download
26
+ # Always download to cache and return the path
27
+ return hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME, cache_dir=MODEL_CACHE_DIR)
28
+
29
+ # Load model with better error handling
30
+ @st.cache_resource
31
+ def load_agent():
32
+ try:
33
+ model_path = ensure_model()
34
+ if model_path is None:
35
+ return None
36
+
37
+ agent = O2Agent()
38
+ agent.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
39
+ agent.model.eval()
40
+ return agent
41
+ except Exception as e:
42
+ st.error(f"Error loading agent: {e}")
43
+ return None
44
+
45
+ def render_svg(svg):
46
+ b64 = base64.b64encode(svg.encode('utf-8')).decode('utf-8')
47
+ return f"<img src='data:image/svg+xml;base64,{b64}'/>", b64
48
+
49
+ # --- Move parsing utility ---
50
+ def parse_move_input(move_input, board):
51
+ if not move_input:
52
+ return None
53
+ move_input = move_input.strip()
54
+ if len(move_input) >= 4 and move_input[2:4].isalnum():
55
+ try:
56
+ move = chess.Move.from_uci(move_input.lower())
57
+ if move in board.legal_moves:
58
+ return move
59
+ except:
60
+ pass
61
+ try:
62
+ move = board.parse_san(move_input)
63
+ if move in board.legal_moves:
64
+ return move
65
+ except:
66
+ pass
67
+ try:
68
+ variations = [move_input.upper(), move_input.lower(), move_input.capitalize()]
69
+ for variation in variations:
70
+ try:
71
+ move = board.parse_san(variation)
72
+ if move in board.legal_moves:
73
+ return move
74
+ except:
75
+ continue
76
+ except:
77
+ pass
78
+ if len(move_input) == 2 and move_input[0].lower() in 'abcdefgh' and move_input[1] in '12345678':
79
+ try:
80
+ move = board.parse_san(move_input.lower())
81
+ if move in board.legal_moves:
82
+ return move
83
+ except:
84
+ pass
85
+ castling_variations = {
86
+ '0-0': 'O-O', '0-0-0': 'O-O-O', 'oo': 'O-O', 'ooo': 'O-O-O', 'o-o': 'O-O', 'o-o-o': 'O-O-O',
87
+ }
88
+ lower_input = move_input.lower()
89
+ if lower_input in castling_variations:
90
+ try:
91
+ move = board.parse_san(castling_variations[lower_input])
92
+ if move in board.legal_moves:
93
+ return move
94
+ except:
95
+ pass
96
+ return None
97
+
98
+ # --- Main UI ---
99
+ if "board" not in st.session_state:
100
+ st.session_state.board = chess.Board()
101
+ if "history" not in st.session_state:
102
+ st.session_state.history = []
103
+
104
+ # Load agent with error handling
105
+ agent = None
106
+ agent_loaded = False
107
+
108
+ with st.spinner("Loading o2 model..."):
109
+ try:
110
+ agent = load_agent()
111
+ if agent is not None:
112
+ agent_loaded = True
113
+ st.success("o2 model loaded successfully!")
114
+ else:
115
+ st.warning("Failed to load o2 model. Using random moves for AI.")
116
+ except Exception as e:
117
+ st.error(f"Failed to load o2: {e}")
118
+ st.warning("Using random moves for AI.")
119
+
120
+ board = st.session_state.board
121
+ history = st.session_state.history
122
+
123
+ st.title("♟️ Play Chess vs o2")
124
+
125
+ if not agent_loaded:
126
+ st.info("🎲 AI is using random moves (o2 model not available)")
127
+
128
+ if st.button("Reset Game"):
129
+ st.session_state.board = chess.Board()
130
+ st.session_state.history = []
131
+ st.rerun()
132
+
133
+ # Create two columns for layout
134
+ col_board, col_pgn = st.columns([2, 1])
135
+
136
+ with col_board:
137
+ board_placeholder = st.empty()
138
+
139
+ def render_board():
140
+ try:
141
+ last_move = board.peek() if board.move_stack else None
142
+ svg_board = chess.svg.board(board=board, lastmove=last_move, size=400)
143
+ board_placeholder.markdown(f'<div style="display: flex; justify-content: center;">{svg_board}</div>', unsafe_allow_html=True)
144
+ except Exception as e:
145
+ st.error(f"Error rendering board: {e}")
146
+
147
+ render_board()
148
+
149
+ col1, col2 = st.columns(2)
150
+ with col1:
151
+ st.write(f"**Turn:** {'White' if board.turn == chess.WHITE else 'Black'}")
152
+ with col2:
153
+ if board.is_check():
154
+ st.write("**Check!**")
155
+
156
+ with col_pgn:
157
+ st.write("### Game History")
158
+ pgn_placeholder = st.empty()
159
+
160
+ def render_pgn():
161
+ if history:
162
+ try:
163
+ game = chess.pgn.Game()
164
+ game.headers["Event"] = "Human vs o2"
165
+ game.headers["White"] = "Human"
166
+ game.headers["Black"] = "o2" if agent_loaded else "Random AI"
167
+ node = game
168
+ temp_board = chess.Board()
169
+ for uci in history:
170
+ move = chess.Move.from_uci(uci)
171
+ if move in temp_board.legal_moves:
172
+ node = node.add_main_variation(move)
173
+ temp_board.push(move)
174
+ else:
175
+ break
176
+ pgn_placeholder.code(str(game), language="pgn")
177
+ except Exception as e:
178
+ move_pairs = []
179
+ for i in range(0, len(history), 2):
180
+ white_move = history[i]
181
+ black_move = history[i+1] if i+1 < len(history) else ""
182
+ move_pairs.append(f"{i//2 + 1}. {white_move} {black_move}")
183
+ pgn_placeholder.code("\n".join(move_pairs))
184
+ else:
185
+ pgn_placeholder.text("No moves yet")
186
+
187
+ render_pgn()
188
+
189
+ if not board.is_game_over() and board.turn == chess.WHITE:
190
+ st.write("### Your Turn (White)")
191
+ legal_moves = list(board.legal_moves)
192
+ legal_moves_uci = [move.uci() for move in legal_moves]
193
+ legal_moves_san = []
194
+ for move in legal_moves:
195
+ try:
196
+ san = board.san(move)
197
+ legal_moves_san.append(san)
198
+ except:
199
+ legal_moves_san.append(move.uci())
200
+ with st.expander("Show legal moves"):
201
+ col1, col2 = st.columns(2)
202
+ with col1:
203
+ st.write("**Algebraic notation:**")
204
+ st.write(", ".join(sorted(legal_moves_san)))
205
+ with col2:
206
+ st.write("**UCI notation:**")
207
+ st.write(", ".join(sorted(legal_moves_uci)))
208
+ user_move = st.text_input("Enter your move (e.g., E4, Nf3, e2e4, O-O):", key="move_input", help="You can use algebraic notation (E4, Nf3) or UCI notation (e2e4). Case doesn't matter!")
209
+ col1, col2 = st.columns(2)
210
+ with col1:
211
+ if st.button("Submit Move"):
212
+ if user_move:
213
+ parsed_move = parse_move_input(user_move, board)
214
+ if parsed_move:
215
+ try:
216
+ board.push(parsed_move)
217
+ history.append(parsed_move.uci())
218
+ render_board() # Update board immediately
219
+ render_pgn() # Update PGN immediately
220
+ st.success(f"You played: {board.san(parsed_move)} ({parsed_move.uci()})")
221
+ st.rerun()
222
+ except Exception as e:
223
+ st.error(f"Error making move: {e}")
224
+ else:
225
+ st.warning(f"Invalid move: '{user_move}'. Please check the legal moves above.")
226
+ else:
227
+ st.warning("Please enter a move.")
228
+ with col2:
229
+ if st.button("Random Move"):
230
+ if legal_moves:
231
+ random_move = random.choice(legal_moves)
232
+ board.push(random_move)
233
+ history.append(random_move.uci())
234
+ render_board() # Update board immediately
235
+ render_pgn() # Update PGN immediately
236
+ st.rerun()
237
+
238
+ if not board.is_game_over() and board.turn == chess.BLACK:
239
+ st.write("### o2's Turn (Black)")
240
+ with st.spinner("o2 is thinking..."):
241
+ try:
242
+ if agent_loaded and agent:
243
+ # Use temperature sampling for first 10 moves, then greedy
244
+ if len(history) < 20: # 10 moves per side
245
+ move = agent.select_move(board, use_mcts=True, simulations=30, temperature=1.2)
246
+ else:
247
+ move = agent.select_move(board, use_mcts=True, simulations=30, temperature=0.0)
248
+ else:
249
+ legal_moves = list(board.legal_moves)
250
+ move = random.choice(legal_moves) if legal_moves else None
251
+ if move and move in board.legal_moves:
252
+ move_san = board.san(move)
253
+ board.push(move)
254
+ history.append(move.uci())
255
+ st.success(f"o2 played: {move_san} ({move.uci()})")
256
+ st.rerun()
257
+ else:
258
+ st.error("o2 couldn't find a valid move")
259
+ except Exception as e:
260
+ st.error(f"Error during o2 move: {e}")
261
+
262
+ if board.is_game_over():
263
+ st.write("### Game Over!")
264
+ result = board.result()
265
+ outcome = board.outcome()
266
+ if result == "1-0":
267
+ st.success("White wins!")
268
+ elif result == "0-1":
269
+ st.error("Black wins!")
270
+ else:
271
+ st.info("Draw!")
272
+ st.write(f"**Result:** {result}")
273
+ st.write(f"**Termination:** {outcome.termination.name}")
274
+ if st.button("Start New Game"):
275
+ st.session_state.board = chess.Board()
276
+ st.session_state.history = []
277
+ st.rerun()
src/mcts.py CHANGED
@@ -43,11 +43,10 @@ class MCTS:
43
  for n in reversed(search_path):
44
  n.N += 1
45
  n.W += value
46
- n.Q = n.W / n.N
47
  value = -value # Switch perspective
48
  # Temperature-based sampling for opening diversity
49
  if temperature and temperature > 0:
50
- import numpy as np
51
  moves = list(root.children.keys())
52
  visits = np.array([root.children[m].N for m in moves], dtype=np.float32)
53
  probs = visits ** (1.0 / temperature)
@@ -71,27 +70,51 @@ class MCTS:
71
  with torch.no_grad():
72
  policy, value = self.model(tensor)
73
  policy = torch.softmax(policy, dim=1).numpy()[0]
 
 
74
  legal_moves = list(node.board.legal_moves)
75
- total_p = 1e-8
76
  for move in legal_moves:
77
- idx = self.move_to_index(move)
78
- p = policy[idx]
79
- total_p += p
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  for move in legal_moves:
81
- idx = self.move_to_index(move)
82
- p = policy[idx] / total_p
83
- child_board = node.board.copy()
84
- child_board.push(move)
85
- child = MCTSNode(child_board, parent=node, move=move)
86
- child.P = p
87
- node.children[move] = child
 
 
 
 
 
88
  return value.item()
89
-
90
  def move_to_index(self, move):
91
  from_square = move.from_square
92
  to_square = move.to_square
93
  promotion = move.promotion if move.promotion else 0
94
- promotion_offset = 0
 
 
95
  if promotion:
96
- promotion_offset = 4096 + (promotion - 1)
97
- return from_square * 64 + to_square + promotion_offset
 
 
 
43
  for n in reversed(search_path):
44
  n.N += 1
45
  n.W += value
46
+ n.Q = n.W / n.N if n.N > 0 else 0.0
47
  value = -value # Switch perspective
48
  # Temperature-based sampling for opening diversity
49
  if temperature and temperature > 0:
 
50
  moves = list(root.children.keys())
51
  visits = np.array([root.children[m].N for m in moves], dtype=np.float32)
52
  probs = visits ** (1.0 / temperature)
 
70
  with torch.no_grad():
71
  policy, value = self.model(tensor)
72
  policy = torch.softmax(policy, dim=1).numpy()[0]
73
+ assert len(policy) == 4672, f"Policy size mismatch: expected 4672, got {len(policy)}"
74
+
75
  legal_moves = list(node.board.legal_moves)
76
+ total_p = 1e-8 # Small epsilon to prevent division by zero
77
  for move in legal_moves:
78
+ try:
79
+ idx = self.move_to_index(move)
80
+ if 0 <= idx < 4672: # Ensure index is within bounds
81
+ p = policy[idx]
82
+ total_p += p
83
+ except Exception:
84
+ continue # Skip moves that can't be indexed properly
85
+
86
+ if total_p < 1e-8: # If all probabilities are extremely small
87
+ total_p = 1.0 # Fall back to uniform distribution
88
+ # Use uniform distribution only for legal moves
89
+ for move in legal_moves:
90
+ idx = self.move_to_index(move)
91
+ if 0 <= idx < 4672:
92
+ policy[idx] = 1.0 / len(legal_moves)
93
+
94
+ # Create child nodes only for valid moves
95
  for move in legal_moves:
96
+ try:
97
+ idx = self.move_to_index(move)
98
+ if 0 <= idx < 4672:
99
+ p = policy[idx] / total_p
100
+ child_board = node.board.copy()
101
+ child_board.push(move)
102
+ child = MCTSNode(child_board, parent=node, move=move)
103
+ child.P = p
104
+ node.children[move] = child
105
+ except Exception:
106
+ continue # Skip problematic moves
107
+
108
  return value.item()
 
109
  def move_to_index(self, move):
110
  from_square = move.from_square
111
  to_square = move.to_square
112
  promotion = move.promotion if move.promotion else 0
113
+ # Base index for normal moves
114
+ idx = from_square * 64 + to_square
115
+ # Handle promotions (knight=1, bishop=2, rook=3, queen=4)
116
  if promotion:
117
+ # Map to indices after normal moves (4096 onwards)
118
+ idx = 4096 + ((promotion - 1) * 64 * 64 // 4) + (from_square * 8 + to_square // 8)
119
+ # Ensure index is within bounds (4672 = 64*64 + 64*8)
120
+ return min(idx, 4671)
src/o2_agent.py CHANGED
@@ -1,88 +1,86 @@
1
- import numpy as np
2
- import chess
3
- import torch
4
- from o2_model import O2Net, board_to_tensor
5
- from mcts import MCTS
6
- import random
7
-
8
- # Optional: Endgame tablebase and opening book integration placeholders
9
- # You can use python-chess's tablebase and opening book modules if desired
10
- # Example for endgame tablebase:
11
- # from chess import tablebase
12
- # tb = tablebase.Tablebase()
13
- # tb.add_tablebase('/path/to/syzygy')
14
- # if tb.probe_wdl(board) is not None:
15
- # # Use tablebase move
16
- # Example for opening book:
17
- # from chess.polyglot import open_reader
18
- # with open_reader('book.bin') as reader:
19
- # entry = reader.find(board)
20
- # move = entry.move
21
-
22
- class O2Agent:
23
- def __init__(self, model_path=None):
24
- self.model = O2Net()
25
- if model_path:
26
- self.model.load_state_dict(torch.load(model_path))
27
- self.model.eval()
28
-
29
- def select_move(self, board, use_mcts=True, simulations=100, temperature=0.0):
30
- if use_mcts:
31
- mcts = MCTS(self.model, simulations=simulations)
32
- return mcts.run(board, temperature=temperature)
33
- # SAFEGUARD IMPORT (add this line)
34
- import numpy as np
35
- tensor = torch.tensor(board_to_tensor(board)).unsqueeze(0)
36
- with torch.no_grad():
37
- policy, _ = self.model(tensor)
38
- legal_moves = list(board.legal_moves)
39
- move_scores = []
40
- for move in legal_moves:
41
- move_idx = self.move_to_index(move)
42
- move_scores.append(policy[0, move_idx].item())
43
- if temperature and temperature > 0:
44
- # Softmax sampling
45
- scores = np.array(move_scores)
46
- exp_scores = np.exp(scores / temperature)
47
- probs = exp_scores / np.sum(exp_scores)
48
- move = np.random.choice(legal_moves, p=probs)
49
- return move
50
- best_move = legal_moves[int(torch.tensor(move_scores).argmax())]
51
- return best_move
52
-
53
- def move_to_index(self, move):
54
- # Encode move as from_square * 64 + to_square + promotion_offset
55
- from_square = move.from_square
56
- to_square = move.to_square
57
- promotion = move.promotion if move.promotion else 0
58
- promotion_offset = 0
59
- if promotion:
60
- # Promotion: 1=Knight, 2=Bishop, 3=Rook, 4=Queen (python-chess)
61
- # Offset: 4096 + (promotion-1)*64*64//4
62
- promotion_offset = 4096 + (promotion - 1) * 256
63
- idx = from_square * 64 + to_square + promotion_offset
64
- # Ensure index is within bounds
65
- return idx if idx < 4672 else idx % 4672
66
-
67
- def index_to_move(self, board, index):
68
- # Decode index to move (reverse of move_to_index)
69
- if index >= 4096:
70
- promotion = (index - 4096) % 4 + 1
71
- idx = index - 4096
72
- from_square = idx // 64
73
- to_square = idx % 64
74
- move = chess.Move(from_square, to_square, promotion=promotion)
75
- else:
76
- from_square = index // 64
77
- to_square = index % 64
78
- move = chess.Move(from_square, to_square)
79
- if move in board.legal_moves:
80
- return move
81
- # Fallback: pick a random legal move
82
- return random.choice(list(board.legal_moves))
83
-
84
- if __name__ == "__main__":
85
- board = chess.Board()
86
- agent = O2Agent()
87
- move = agent.select_move(board)
88
- print("O2 selects:", move)
 
1
+ import numpy as np
2
+ import chess
3
+ import torch
4
+ from o2_model import O2Net, board_to_tensor
5
+ from mcts import MCTS
6
+ import random
7
+
8
+ # Optional: Endgame tablebase and opening book integration placeholders
9
+ # You can use python-chess's tablebase and opening book modules if desired
10
+ # Example for endgame tablebase:
11
+ # from chess import tablebase
12
+ # tb = tablebase.Tablebase()
13
+ # tb.add_tablebase('/path/to/syzygy')
14
+ # if tb.probe_wdl(board) is not None:
15
+ # # Use tablebase move
16
+ # Example for opening book:
17
+ # from chess.polyglot import open_reader
18
+ # with open_reader('book.bin') as reader:
19
+ # entry = reader.find(board)
20
+ # move = entry.move
21
+
22
+ class O2Agent:
23
+ def __init__(self, model_path=None):
24
+ self.model = O2Net()
25
+ if model_path:
26
+ self.model.load_state_dict(torch.load(model_path))
27
+ self.model.eval()
28
+
29
+ def select_move(self, board, use_mcts=True, simulations=100, temperature=0.0):
30
+ if use_mcts:
31
+ mcts = MCTS(self.model, simulations=simulations)
32
+ return mcts.run(board, temperature=temperature)
33
+ tensor = torch.tensor(board_to_tensor(board)).unsqueeze(0)
34
+ with torch.no_grad():
35
+ policy, _ = self.model(tensor)
36
+ legal_moves = list(board.legal_moves)
37
+ move_scores = []
38
+ for move in legal_moves:
39
+ move_idx = self.move_to_index(move)
40
+ move_scores.append(policy[0, move_idx].item())
41
+ if temperature and temperature > 0:
42
+ # Softmax sampling
43
+ scores = np.array(move_scores)
44
+ exp_scores = np.exp(scores / temperature)
45
+ probs = exp_scores / np.sum(exp_scores)
46
+ move = np.random.choice(legal_moves, p=probs)
47
+ return move
48
+ best_move = legal_moves[int(torch.tensor(move_scores).argmax())]
49
+ return best_move
50
+
51
+ def move_to_index(self, move):
52
+ # Encode move as from_square * 64 + to_square + promotion_offset
53
+ from_square = move.from_square
54
+ to_square = move.to_square
55
+ promotion = move.promotion if move.promotion else 0
56
+ promotion_offset = 0
57
+ if promotion:
58
+ # Promotion: 1=Knight, 2=Bishop, 3=Rook, 4=Queen (python-chess)
59
+ # Offset: 4096 + (promotion-1)*64*64//4
60
+ promotion_offset = 4096 + (promotion - 1) * 256
61
+ idx = from_square * 64 + to_square + promotion_offset
62
+ # Ensure index is within bounds
63
+ return idx if idx < 4672 else idx % 4672
64
+
65
+ def index_to_move(self, board, index):
66
+ # Decode index to move (reverse of move_to_index)
67
+ if index >= 4096:
68
+ promotion = (index - 4096) % 4 + 1
69
+ idx = index - 4096
70
+ from_square = idx // 64
71
+ to_square = idx % 64
72
+ move = chess.Move(from_square, to_square, promotion=promotion)
73
+ else:
74
+ from_square = index // 64
75
+ to_square = index % 64
76
+ move = chess.Move(from_square, to_square)
77
+ if move in board.legal_moves:
78
+ return move
79
+ # Fallback: pick a random legal move
80
+ return random.choice(list(board.legal_moves))
81
+
82
+ if __name__ == "__main__":
83
+ board = chess.Board()
84
+ agent = O2Agent()
85
+ move = agent.select_move(board)
86
+ print("O2 selects:", move)