FlameF0X committed
Commit 8806ce1 · verified · Parent: 8a01c41

Upload 11 files

src/o1/__init__.py ADDED
@@ -0,0 +1 @@
+ # o1 package
src/o1/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (146 Bytes).

src/o1/__pycache__/agent.cpython-312.pyc ADDED
Binary file (9.63 kB).

src/o1/__pycache__/mcts.cpython-312.pyc ADDED
Binary file (5.4 kB).

src/o1/__pycache__/train.cpython-312.pyc ADDED
Binary file (8.77 kB).

src/o1/__pycache__/utils.cpython-312.pyc ADDED
Binary file (7 kB).
src/o1/agent.py ADDED
@@ -0,0 +1,133 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ import chess
+
+ class SEBlock(nn.Module):
+     def __init__(self, channels, reduction=16):
+         super().__init__()
+         self.fc1 = nn.Linear(channels, channels // reduction)
+         self.fc2 = nn.Linear(channels // reduction, channels)
+
+     def forward(self, x):
+         b, c, h, w = x.size()
+         y = x.view(b, c, -1).mean(dim=2)
+         y = F.relu(self.fc1(y))
+         y = torch.sigmoid(self.fc2(y))
+         y = y.view(b, c, 1, 1)
+         return x * y
+
+ class ResidualBlock(nn.Module):
+     def __init__(self, channels, dropout=0.2):
+         super().__init__()
+         self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
+         self.bn1 = nn.BatchNorm2d(channels)
+         self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
+         self.bn2 = nn.BatchNorm2d(channels)
+         self.se = SEBlock(channels)
+         self.dropout = nn.Dropout2d(dropout)
+
+     def forward(self, x):
+         residual = x
+         out = F.relu(self.bn1(self.conv1(x)))
+         out = self.bn2(self.conv2(out))
+         out = self.se(out)
+         out = self.dropout(out)
+         out += residual
+         return F.relu(out)
+
+ class ChessNet(nn.Module):
+     def __init__(self, input_channels=17, board_size=8, policy_size=4672, num_blocks=20):
+         super().__init__()
+         self.conv_in = nn.Conv2d(input_channels, 256, kernel_size=3, padding=1)
+         self.bn_in = nn.BatchNorm2d(256)
+         self.res_blocks = nn.Sequential(*[ResidualBlock(256) for _ in range(num_blocks)])
+         self.fc1 = nn.Linear(256 * board_size * board_size, 512)
+         self.ln_fc1 = nn.LayerNorm(512)
+         # Policy head
+         self.policy_head1 = nn.Linear(512, 256)
+         self.policy_head2 = nn.Linear(256, policy_size)
+         # Value head
+         self.value_head1 = nn.Linear(512, 128)
+         self.value_head2 = nn.Linear(128, 1)
+
+     def forward(self, x):
+         x = F.relu(self.bn_in(self.conv_in(x)))
+         x = self.res_blocks(x)
+         x = x.view(x.size(0), -1)
+         x = F.relu(self.ln_fc1(self.fc1(x)))
+         # Policy head
+         policy = F.relu(self.policy_head1(x))
+         policy = self.policy_head2(policy)
+         # Value head
+         value = F.relu(self.value_head1(x))
+         value = torch.tanh(self.value_head2(value))
+         return policy, value
+
+ class Agent:
+     def __init__(self, device='cpu'):
+         self.device = device
+         self.model = ChessNet().to(device)
+         self.model.eval()
+
+     def board_to_tensor(self, board):
+         # Planes 0-11: one binary 8x8 plane per piece type and color
+         piece_map = board.piece_map()
+         tensor = np.zeros((17, 8, 8), dtype=np.float32)
+         for square, piece in piece_map.items():
+             idx = self.piece_to_index(piece)
+             row, col = divmod(square, 8)
+             tensor[idx, row, col] = 1
+         # Planes 12-15: castling rights
+         if board.has_kingside_castling_rights(chess.WHITE):
+             tensor[12, :, :] = 1
+         if board.has_queenside_castling_rights(chess.WHITE):
+             tensor[13, :, :] = 1
+         if board.has_kingside_castling_rights(chess.BLACK):
+             tensor[14, :, :] = 1
+         if board.has_queenside_castling_rights(chess.BLACK):
+             tensor[15, :, :] = 1
+         # Plane 16: normalized fullmove counter
+         tensor[16, :, :] = board.fullmove_number / 100.0
+         # Optionally, add repetition or other features here
+         return torch.tensor(tensor, device=self.device).unsqueeze(0)
+
+     def piece_to_index(self, piece):
+         # 0-5: white P,N,B,R,Q,K; 6-11: black P,N,B,R,Q,K
+         offset = 0 if piece.color == chess.WHITE else 6
+         piece_type_map = {
+             chess.PAWN: 0,
+             chess.KNIGHT: 1,
+             chess.BISHOP: 2,
+             chess.ROOK: 3,
+             chess.QUEEN: 4,
+             chess.KING: 5
+         }
+         return offset + piece_type_map[piece.piece_type]
+
+     def predict(self, board):
+         x = self.board_to_tensor(board)
+         with torch.no_grad():
+             policy_logits, value = self.model(x)
+         return policy_logits, value
+
+     def diffusion_sample(self, policy_logits, steps=10, noise_scale=1.0):
+         """
+         Apply a simple diffusion-style process to the policy logits:
+         at each step, add Gaussian noise and denoise by averaging with the original logits.
+         """
+         x = policy_logits.clone()
+         orig = policy_logits.clone()
+         for _ in range(steps):
+             noise = torch.randn_like(x) * noise_scale
+             x = x + noise
+             x = (x + orig) / 2  # simple denoising step
+         return x
+
+     def predict_with_diffusion(self, board, steps=10, noise_scale=1.0):
+         x = self.board_to_tensor(board)
+         with torch.no_grad():
+             policy_logits, value = self.model(x)
+         diffused_logits = self.diffusion_sample(policy_logits, steps=steps, noise_scale=noise_scale)
+         return diffused_logits, value
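
Usage sketch (editorial aside, not part of this commit): assuming the package is importable as `o1` (e.g. with `src` on `PYTHONPATH`), the `Agent` above can be queried on the starting position as follows. Because `diffusion_sample` averages each noisy step back toward the original logits, the diffused policy stays centered on the network's prediction while adding exploration noise.

import chess
import torch
from o1.agent import Agent

agent = Agent(device='cpu')          # randomly initialized ChessNet in eval mode
board = chess.Board()                # standard starting position

policy_logits, value = agent.predict(board)                          # shapes: (1, 4672) and (1, 1)
diffused_logits, _ = agent.predict_with_diffusion(board, steps=5, noise_scale=0.5)

move_probs = torch.softmax(diffused_logits, dim=1)                   # normalized move distribution
print(policy_logits.shape, value.item(), move_probs.sum().item())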
src/o1/mcts.py ADDED
@@ -0,0 +1,96 @@
+ """
+ Monte Carlo Tree Search (MCTS) for the o1 agent.
+ Basic implementation: runs simulations and selects moves by visit count.
+ Integrate with the neural net for policy/value guidance for full strength.
+ """
+ import chess
+ import random
+ from collections import defaultdict
+ import torch
+
+ class MCTSNode:
+     def __init__(self, board, parent=None, move=None):
+         self.board = board.copy()
+         self.parent = parent
+         self.move = move
+         self.children = []
+         self.visits = 0
+         self.value = 0.0
+         self.untried_moves = list(board.legal_moves)
+
+     def is_fully_expanded(self):
+         return len(self.untried_moves) == 0
+
+     def best_child(self, c_param=1.4):
+         choices = [
+             (child.value / (child.visits + 1e-6) + c_param * ((2 * (self.visits + 1e-6)) ** 0.5 / (child.visits + 1e-6)), child)
+             for child in self.children
+         ]
+         return max(choices, key=lambda x: x[0])[1]
+
+ class MCTS:
+     def __init__(self, agent=None, simulations=50):
+         self.agent = agent
+         self.simulations = simulations
+
+     def search(self, board, restrict_top_n=None):
+         root = MCTSNode(board)
+         for _ in range(self.simulations):
+             node = root
+             sim_board = board.copy()
+             # Selection
+             while node.is_fully_expanded() and node.children:
+                 node = node.best_child()
+                 sim_board.push(node.move)
+             # Expansion
+             if node.untried_moves:
+                 move = random.choice(node.untried_moves)
+                 sim_board.push(move)
+                 child = MCTSNode(sim_board, parent=node, move=move)
+                 node.children.append(child)
+                 node.untried_moves.remove(move)
+                 node = child
+             # Simulation
+             result = self.simulate(sim_board)
+             # Backpropagation
+             # Flip the sign on alternating plies so the value is credited from the correct player's perspective
+             invert = False
+             temp_node = node
+             while temp_node.parent is not None:
+                 temp_node = temp_node.parent
+                 invert = not invert
+             value = -result if invert else result
+             while node:
+                 node.visits += 1
+                 node.value += value
+                 node = node.parent
+         # Choose the move with the most visits, restricted to the top-N if specified
+         if not root.children:
+             return random.choice(list(board.legal_moves))
+         children_sorted = sorted(root.children, key=lambda c: c.visits, reverse=True)
+         if restrict_top_n is not None and restrict_top_n < len(children_sorted):
+             # Only consider the top-N moves
+             children_sorted = children_sorted[:restrict_top_n]
+         best = max(children_sorted, key=lambda c: c.visits)
+         return best.move
+
+     def simulate(self, board, use_diffusion=True, diffusion_steps=10, noise_scale=1.0):
+         # Use the neural network to evaluate the board instead of a random playout
+         if self.agent is not None:
+             with torch.no_grad():
+                 if use_diffusion and hasattr(self.agent, 'predict_with_diffusion'):
+                     _, value = self.agent.predict_with_diffusion(board, steps=diffusion_steps, noise_scale=noise_scale)
+                 else:
+                     _, value = self.agent.predict(board)
+             return value.item()
+         # Fallback: play random moves until the game ends
+         while not board.is_game_over():
+             move = random.choice(list(board.legal_moves))
+             board.push(move)
+         result = board.result()
+         if result == '1-0':
+             return 1
+         elif result == '0-1':
+             return -1
+         else:
+             return 0
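
A minimal driving sketch for the search above (again an editorial aside, not part of the diff): with an `Agent`, `simulate()` scores each leaf with the value head, optionally through the diffusion wrapper; with `agent=None` it falls back to random playouts. The simulation budget is kept small so the example finishes quickly.

import chess
from o1.agent import Agent
from o1.mcts import MCTS

board = chess.Board()

# Network-guided search, restricted to the 5 most-visited root moves.
mcts = MCTS(agent=Agent(device='cpu'), simulations=20)
print("net-guided move:", mcts.search(board, restrict_top_n=5))

# Pure random-rollout search (no agent attached).
plain_mcts = MCTS(agent=None, simulations=20)
print("rollout move:", plain_mcts.search(board))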
src/o1/selfplay.py ADDED
@@ -0,0 +1,37 @@
+ """
+ Self-play orchestration for the o1 agent.
+ Runs self-play games using MCTS for move selection.
+ """
+ import chess
+ from o1.mcts import MCTS
+
+ def run_selfplay(agent, num_games=1, simulations=50):
+     """Run self-play games using MCTS and return the experience."""
+     all_experience = []
+     for game_idx in range(num_games):
+         board = chess.Board()
+         mcts = MCTS(agent, simulations=simulations)
+         game_data = []
+         while not board.is_game_over():
+             move = mcts.search(board)
+             state_tensor = agent.board_to_tensor(board)
+             # Policy target: one-hot for the chosen move (for now)
+             policy = [0] * 4672  # 4672 = size of the fixed AlphaZero-style move encoding
+             move_idx = list(board.legal_moves).index(move)
+             policy[move_idx] = 1
+             value = 0  # Placeholder, set after the game
+             game_data.append((state_tensor, policy, value))
+             board.push(move)
+         # Assign the final result as the value for all positions
+         result = board.result()
+         if result == '1-0':
+             z = 5
+         elif result == '0-1':
+             z = -1
+         else:
+             z = 0
+         game_data = [(s, p, z) for (s, p, v) in game_data]
+         all_experience.extend(game_data)
+     return all_experience
+
+ # Additional self-play orchestration (e.g. parallel games) can go here
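
A hedged sketch of how `run_selfplay` could feed the `ExperienceBuffer` defined in `train.py` below; the game and simulation counts are deliberately tiny and purely illustrative.

from o1.agent import Agent
from o1.selfplay import run_selfplay
from o1.train import ExperienceBuffer

agent = Agent(device='cpu')
experience = run_selfplay(agent, num_games=1, simulations=5)

buffer = ExperienceBuffer(max_size=10000)
for example in experience:   # each example is (state_tensor, policy_list, z)
    buffer.add(example)
print(f"collected {len(buffer.buffer)} positions")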
src/o1/train.py ADDED
@@ -0,0 +1,162 @@
+ import chess
+ import random
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from o1.agent import Agent
+ from o1.mcts import MCTS
+ from o1.utils import save_board_svg, save_model
+
+ class ExperienceBuffer:
+     def __init__(self, max_size=10000):
+         self.buffer = []
+         self.max_size = max_size
+     def add(self, experience):
+         if len(self.buffer) >= self.max_size:
+             self.buffer.pop(0)
+         self.buffer.append(experience)
+     def sample(self, batch_size):
+         return random.sample(self.buffer, min(batch_size, len(self.buffer)))
+     def get_tensors(self, batch):
+         # Convert a batch of (state_tensor, policy, value) to tensors
+         # Ensure state tensors are float32 and have the correct shape
+         states = torch.cat([s.float() for (s, _, _) in batch], dim=0)
+         policies = torch.tensor([p for (_, p, _) in batch], dtype=torch.float32)
+         values = torch.tensor([v for (_, _, v) in batch], dtype=torch.float32).unsqueeze(1)
+         return states, policies, values
+
+ def self_play_game(agent, simulations=10, save_svg=False, svg_prefix="game", max_moves=40):
+     # Randomly choose o1's color for this game
+     o1_color = random.choice([chess.WHITE, chess.BLACK])
+     board = chess.Board()
+     mcts = MCTS(agent, simulations=simulations)
+     game_data = []
+     move_num = 0
+     print(f"o1 is playing as {'White' if o1_color == chess.WHITE else 'Black'}")
+
+     while not board.is_game_over() and move_num < max_moves:
+         # Determine whether it is o1's turn
+         o1_turn = (board.turn == o1_color)
+         if o1_turn:
+             move = mcts.search(board)
+         else:
+             # Opponent: random move
+             move = random.choice(list(board.legal_moves))
+         print(f"Move {move_num + 1}: {move}")
+         state_tensor = agent.board_to_tensor(board)
+         policy = [0] * 4672
+         move_idx = list(board.legal_moves).index(move)
+         policy[move_idx] = 1
+         value = 0  # Placeholder, set after the game
+         game_data.append((state_tensor, policy, value))
+         board.push(move)
+         if save_svg:
+             save_board_svg(board, f"{svg_prefix}_move{move_num}.svg")
+         move_num += 1
+
+     print(f"Game ended after {move_num} moves")
+     print(f"Final position:\n{board}")
+
+     penalty = 0
+     if board.is_game_over():
+         outcome = board.outcome(claim_draw=True)
+         if outcome:
+             termination = outcome.termination.name
+             if outcome.winner is None:
+                 if termination == "STALEMATE":
+                     winner_str = "Draw (stalemate)"
+                     z = 0
+                 elif termination == "INSUFFICIENT_MATERIAL":
+                     winner_str = "Draw (insufficient material)"
+                     z = 0
+                     penalty = z
+                 else:
+                     winner_str = f"Draw ({termination.lower()})"
+                     z = 0
+                     penalty = z
+             elif outcome.winner:
+                 winner_str = "White wins"
+                 if o1_color == chess.WHITE:
+                     z = 5
+                 else:
+                     z = -1  # Penalize o1 if it was Black and lost
+             else:
+                 winner_str = "Black wins"
+                 if o1_color == chess.BLACK:
+                     z = 5
+                 else:
+                     z = -1  # Penalize o1 if it was White and lost
+             print(f"Game over reason: {board.result()} ({termination})")
+             print(f"Result: {winner_str}")
+             if penalty:
+                 print(f"Penalty applied: {penalty}")
+         else:
+             print(f"Game over reason: {board.result()} (unknown termination)")
+             z = 0
+             print(f"Penalty applied: {z}")
+     else:
+         print("Game reached move limit - applying increased penalty")
+         print("Result: No winner (move limit reached)")
+         z = -2.0
+         print(f"Penalty applied: {z}")
+
+     game_data = [(s, p, z) for (s, p, v) in game_data]
+     if save_svg:
+         save_board_svg(board, f"{svg_prefix}_final.svg")
+     return game_data
+
+ def train_step(agent, buffer, optimizer, batch_size=32):
+     if len(buffer.buffer) < batch_size:
+         return
+     batch = buffer.sample(batch_size)
+     states, target_policies, target_values = buffer.get_tensors(batch)
+     agent.model.train()
+     optimizer.zero_grad()
+     pred_policies, pred_values = agent.model(states)
+     # Policy loss (cross-entropy)
+     policy_loss = -torch.sum(target_policies * torch.log_softmax(pred_policies, dim=1)) / batch_size
+     # Value loss (MSE)
+     value_loss = nn.functional.mse_loss(pred_values, target_values)
+     loss = policy_loss + value_loss
+     loss.backward()
+     optimizer.step()
+     print(f"Train step: loss={loss.item():.4f} (policy={policy_loss.item():.4f}, value={value_loss.item():.4f})")
+
+ def main():
+     agent = Agent()
+     # Try to load pretrained weights if available
+     import os
+     from o1.utils import load_model
+     pretrained_path = "trained_agent.pth"
+     if os.path.exists(pretrained_path):
+         print(f"Loading pretrained weights from {pretrained_path}...")
+         load_model(agent, pretrained_path)
+     else:
+         print("No pretrained weights found. Training from scratch.")
+     buffer = ExperienceBuffer()
+     optimizer = optim.Adam(agent.model.parameters(), lr=1e-4)
+     num_games = 20  # Number of self-play games per training run
+     global_reward = 0
+     for i in range(num_games):
+         print(f"Self-play game {i+1}")
+         # Only save board SVGs for the last game
+         save_svgs = (i == num_games - 1)
+         game_experience = self_play_game(agent, simulations=10, max_moves=300,
+                                          save_svg=save_svgs,
+                                          svg_prefix="final_game")
+         for exp in game_experience:
+             buffer.add(exp)
+         # Log the reward for this game (all z values are the same within a game)
+         if game_experience:
+             game_reward = game_experience[0][2]
+             global_reward += game_reward
+             print(f"Reward for this game: {game_reward}")
+             print(f"Cumulative global reward: {global_reward}")
+         train_step(agent, buffer, optimizer)
+     print("Pipeline complete. Self-play now uses MCTS for move selection and real learning.")
+     # Save the trained model at the end
+     save_model(agent, "trained_agent.pth")
+     print("Model saved as trained_agent.pth")
+
+ if __name__ == "__main__":
+     main()
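
For reference, `train_step` adds a cross-entropy policy term to an MSE value term. The snippet below is a self-contained illustration of that loss arithmetic on random tensors (shapes taken from `ChessNet`'s defaults; it does not depend on the modules in this commit).

import torch
import torch.nn as nn

batch_size, policy_size = 32, 4672
pred_policies = torch.randn(batch_size, policy_size)    # raw logits from the policy head
pred_values = torch.tanh(torch.randn(batch_size, 1))    # value head output in [-1, 1]

# One-hot policy targets and random value targets, standing in for buffer samples.
target_policies = torch.zeros(batch_size, policy_size)
target_policies[torch.arange(batch_size), torch.randint(policy_size, (batch_size,))] = 1
target_values = torch.empty(batch_size, 1).uniform_(-1, 1)

policy_loss = -torch.sum(target_policies * torch.log_softmax(pred_policies, dim=1)) / batch_size
value_loss = nn.functional.mse_loss(pred_values, target_values)
loss = policy_loss + value_loss
print(f"policy={policy_loss.item():.4f} value={value_loss.item():.4f} total={loss.item():.4f}")

With one-hot targets, the policy term reduces to the mean negative log-probability assigned to the played move.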
src/o1/utils.py ADDED
@@ -0,0 +1,144 @@
+ """
+ Utility functions for the o1 agent.
+ Includes Elo calculation, FEN helpers, experience save/load, and more.
+ """
+ import pickle
+ import chess
+ import torch
+ import chess.svg
+ import os
+
+ def calculate_elo(rating_a, rating_b, result, k=32):
+     """Update the Elo rating for player A given the result (1=win, 0.5=draw, 0=loss)."""
+     expected = 1 / (1 + 10 ** ((rating_b - rating_a) / 400))
+     new_rating = rating_a + k * (result - expected)
+     return new_rating
+
+ def board_to_fen(board):
+     """Convert a chess.Board to a FEN string."""
+     return board.fen()
+
+ def fen_to_board(fen):
+     """Convert a FEN string to a chess.Board."""
+     return chess.Board(fen)
+
+ def save_experience(buffer, filename):
+     """Save the experience buffer to a file."""
+     with open(filename, 'wb') as f:
+         pickle.dump(buffer.buffer, f)
+
+ def load_experience(filename):
+     """Load an experience buffer from a file."""
+     with open(filename, 'rb') as f:
+         return pickle.load(f)
+
+ def save_model(agent, filename):
+     """Save the agent's model to a file."""
+     torch.save(agent.model.state_dict(), filename)
+
+ def load_model(agent, filename):
+     """Load the agent's model from a file."""
+     agent.model.load_state_dict(torch.load(filename))
+     agent.model.eval()
+
+ def save_board_svg(board, filename):
+     """Save the current board position as an SVG image."""
+     svg_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "game_svgs")
+     os.makedirs(svg_dir, exist_ok=True)
+     filepath = os.path.join(svg_dir, filename)
+     svg = chess.svg.board(board=board)
+     with open(filepath, 'w') as f:
+         f.write(svg)
+
+ # Option 3: Model Architecture Tuning
+ # Try deeper/smaller networks, different block types, or alternative architectures.
+
+ def try_alternative_architectures(agent_class, architectures):
+     """Try different model architectures and return their performance."""
+     results = {}
+     for arch in architectures:
+         agent = agent_class(arch=arch)
+         # Evaluate the agent (placeholder; implement evaluation logic)
+         results[arch] = None  # Fill with the actual evaluation
+     return results
+
+ # Option 4: Hyperparameter Tuning
+ # Use grid/random search to find optimal hyperparameters.
+
+ def grid_search(train_func, param_grid):
+     """Perform a grid search over hyperparameters."""
+     import itertools
+     keys, values = zip(*param_grid.items())
+     best_score = None
+     best_params = None
+     for v in itertools.product(*values):
+         params = dict(zip(keys, v))
+         score = train_func(**params)
+         if best_score is None or score > best_score:
+             best_score = score
+             best_params = params
+     return best_params, best_score
+
+ # Option 5: Regularization
+ # Add dropout, L2 regularization, or early stopping.
+
+ def add_regularization(model, dropout=0.2, l2=1e-4):
+     """Add dropout and L2 regularization to the model."""
+     # This is a placeholder; the actual implementation depends on the model code
+     for module in model.modules():
+         if hasattr(module, 'dropout'):
+             module.dropout.p = dropout
+     return model, l2
+
+ # Option 6: Cross-Validation
+ # Use k-fold cross-validation for robust evaluation.
+
+ def k_fold_cross_validation(train_func, k=5, *args, **kwargs):
+     """Perform k-fold cross-validation."""
+     scores = []
+     for i in range(k):
+         score = train_func(*args, fold=i, **kwargs)
+         scores.append(score)
+     return sum(scores) / len(scores)
+
+ # Option 7: Ensemble Methods
+ # Combine multiple models for better performance.
+
+ def ensemble_predict(models, input_tensor):
+     """Average predictions from multiple models."""
+     outputs = [model(input_tensor) for model in models]
+     # Outputs are assumed to be (policy, value) tuples
+     avg_policy = sum(o[0] for o in outputs) / len(outputs)
+     avg_value = sum(o[1] for o in outputs) / len(outputs)
+     return avg_policy, avg_value
+
+ # Elo evaluation for the model
+
+ def evaluate_model_elo(agent, opponent, num_games=20, initial_elo=1500):
+     """Play games between agent and opponent; return the estimated Elo for the agent."""
+     agent_elo = initial_elo
+     opp_elo = initial_elo
+     import random
+     import chess
+     from o1.mcts import MCTS
+     for i in range(num_games):
+         board = chess.Board()
+         mcts_agent = MCTS(agent)
+         mcts_opp = MCTS(opponent)
+         turn = random.choice([True, False])  # the agent's color (True = White)
+         while not board.is_game_over():
+             if board.turn == turn:
+                 move = mcts_agent.search(board)
+             else:
+                 move = mcts_opp.search(board)
+             board.push(move)
+         result = board.result()
+         if result == '1-0':
+             agent_score = 1 if turn else 0
+         elif result == '0-1':
+             agent_score = 0 if turn else 1
+         else:
+             agent_score = 0.5
+         agent_elo = calculate_elo(agent_elo, opp_elo, agent_score)
+         opp_elo = calculate_elo(opp_elo, agent_elo, 1 - agent_score)
+     return agent_elo
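
A worked example of the `calculate_elo` update above (editorial aside): with equal ratings the expected score is 0.5, so a win at K=32 is worth +16 points, while an upset win by a lower-rated player is worth more.

from o1.utils import calculate_elo

print(calculate_elo(1500, 1500, 1))      # 1516.0: expected score 0.5, gain = 32 * (1 - 0.5)
print(calculate_elo(1500, 1500, 0.5))    # 1500.0: a draw between equals changes nothing

# 1400 beating 1600: expected = 1 / (1 + 10 ** ((1600 - 1400) / 400)) ≈ 0.24,
# so the winner gains about 32 * 0.76 ≈ 24.3 points.
print(round(calculate_elo(1400, 1600, 1), 1))   # ~1424.3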