Create training-script-v2.py
training-script-v2.py (new file, +348 lines)

import chess
import chess.engine
import numpy as np
import tensorflow as tf
import time
import os
import datetime

# --- 1. Neural Network (Policy and Value Network) ---
class PolicyValueNetwork(tf.keras.Model):
    def __init__(self, num_moves):
        super(PolicyValueNetwork, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(32, 3, activation='relu', padding='same')  # input shape (8, 8, 12) is inferred on the first call
        self.flatten = tf.keras.layers.Flatten()
        self.dense_policy = tf.keras.layers.Dense(num_moves, activation='softmax', name='policy_head')
        self.dense_value = tf.keras.layers.Dense(1, activation='tanh', name='value_head')

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.flatten(x)
        policy = self.dense_policy(x)
        value = self.dense_value(x)
        return policy, value

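# Illustrative shape check (a sketch, not part of the training run; run it manually if useful).
# For an input batch of shape (1, 8, 8, 12), the untrained network returns a policy of shape
# (1, NUM_POSSIBLE_MOVES) and a value of shape (1, 1):
#
#   net = PolicyValueNetwork(NUM_POSSIBLE_MOVES)   # NUM_POSSIBLE_MOVES is defined in section 3 below
#   policy, value = net(np.zeros((1, 8, 8, 12), dtype=np.float32))
#   print(policy.shape, value.shape)               # (1, 4672) (1, 1)
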
# --- 2. Board Representation and Preprocessing ---
def board_to_input(board):
    """Encode the board as an 8x8x12 tensor: one binary plane per piece type and colour."""
    piece_types = [chess.PAWN, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN, chess.KING]
    input_planes = np.zeros((8, 8, 12), dtype=np.float32)

    for piece_type_index, piece_type in enumerate(piece_types):
        for square in chess.SQUARES:
            piece = board.piece_at(square)
            if piece is not None and piece.piece_type == piece_type:
                plane_index = piece_type_index if piece.color == chess.WHITE else piece_type_index + 6
                row, col = chess.square_rank(square), chess.square_file(square)
                input_planes[row, col, plane_index] = 1.0
    return input_planes

def get_legal_moves_mask(board):
    """Return a 0/1 vector of length NUM_POSSIBLE_MOVES marking the legal moves."""
    legal_moves = list(board.legal_moves)
    move_indices = [move_to_index(move) for move in legal_moves]

    # Defensive check: drop any index that falls outside the policy head.
    valid_move_indices = [index for index in move_indices if 0 <= index < NUM_POSSIBLE_MOVES]

    mask = np.zeros(NUM_POSSIBLE_MOVES, dtype=np.float32)
    mask[valid_move_indices] = 1.0
    return mask

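# Illustrative sanity check (a sketch, not part of the training run; get_legal_moves_mask relies on
# the move-encoding helpers defined in section 3 below). For the starting position, board_to_input
# yields an (8, 8, 12) tensor whose entries sum to 32 (one per piece), and the legal-move mask
# marks the 20 legal opening moves:
#
#   board = chess.Board()
#   print(board_to_input(board).shape)            # (8, 8, 12)
#   print(board_to_input(board).sum())            # 32.0
#   print(get_legal_moves_mask(board).sum())      # 20.0
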
# --- 3. Move Encoding/Decoding (Deterministic Implementation) ---
NUM_POSSIBLE_MOVES = 4672  # Size of the policy head; the encoding below uses indices 0..4351, the rest stay unused

def move_to_index(move):
    """Deterministic move-to-index conversion."""
    # Non-promotion moves (most common): encode source and target squares
    if move.promotion is None:
        return move.from_square * 64 + move.to_square

    # Promotion moves: offset blocks after the 4096 non-promotion indices, keyed by target square
    if move.promotion == chess.KNIGHT:
        return 4096 + move.to_square
    elif move.promotion == chess.BISHOP:
        return 4096 + 64 + move.to_square
    elif move.promotion == chess.ROOK:
        return 4096 + 64 * 2 + move.to_square
    elif move.promotion == chess.QUEEN:
        return 4096 + 64 * 3 + move.to_square
    else:
        raise ValueError(f"Unknown promotion piece type: {move.promotion}")

def index_to_move(index, board):
    """Deterministic index-to-move conversion (inverse of move_to_index for the given board)."""
    if 0 <= index < 4096:  # Non-promotion moves
        from_square = index // 64
        to_square = index % 64
        move = chess.Move(from_square, to_square)
        return move if move in board.legal_moves else None

    # Promotion moves: the encoding keeps only the target square and the promotion piece,
    # so the source square is recovered from the board's legal moves.
    if 4096 <= index < 4096 + 64:  # Knight promotions
        to_square, promotion = index - 4096, chess.KNIGHT
    elif 4096 + 64 <= index < 4096 + 64 * 2:  # Bishop promotions
        to_square, promotion = index - (4096 + 64), chess.BISHOP
    elif 4096 + 64 * 2 <= index < 4096 + 64 * 3:  # Rook promotions
        to_square, promotion = index - (4096 + 64 * 2), chess.ROOK
    elif 4096 + 64 * 3 <= index < 4096 + 64 * 4:  # Queen promotions
        to_square, promotion = index - (4096 + 64 * 3), chess.QUEEN
    else:
        return None  # Invalid index

    for move in board.legal_moves:
        if move.to_square == to_square and move.promotion == promotion:
            return move
    return None  # No legal move matches this index

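# Illustrative round trip (a sketch, not part of the training run). A quiet move encodes its
# source and target squares; for example e2e4 maps to 12 * 64 + 28 = 796 and decodes back to the
# same move on the starting board:
#
#   m = chess.Move.from_uci("e2e4")
#   idx = move_to_index(m)                        # 796
#   print(index_to_move(idx, chess.Board()))      # Move.from_uci('e2e4')
#
# Promotion indices keep only the target square and the promotion piece, so index_to_move
# recovers the source square from the legal moves of the board it is given.
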
def get_game_result_value(board):
    """Final result from White's perspective: +1 White win, -1 Black win, 0 draw (or game not over)."""
    if board.is_checkmate():
        return 1 if board.turn == chess.BLACK else -1  # the side to move is the side that got mated
    elif board.is_stalemate() or board.is_insufficient_material() or board.is_seventyfive_moves() or board.is_fivefold_repetition() or board.is_variant_draw():
        return 0
    else:
        return 0

# --- 4. Monte Carlo Tree Search (MCTS) ---
class MCTSNode:
    def __init__(self, board, parent=None, prior_prob=0):
        self.board = board.copy()
        self.parent = parent
        self.children = {}
        self.visits = 0
        self.value_sum = 0
        self.prior_prob = prior_prob
        self.policy_prob = 0
        self.value = 0

    def select_child(self, exploration_constant=1.4):
        """PUCT-style selection: mean value plus a prior-weighted exploration bonus."""
        best_child = None
        best_ucb = -float('inf')
        for move, child in self.children.items():
            ucb = child.value + exploration_constant * child.prior_prob * np.sqrt(self.visits) / (1 + child.visits)
            if ucb > best_ucb:
                best_ucb = ucb
                best_child = child
        return best_child

    def expand(self, policy_probs):
        legal_moves = list(self.board.legal_moves)
        for move in legal_moves:
            move_index = move_to_index(move)
            prior_prob = policy_probs[move_index]
            child_board = self.board.copy()
            child_board.push(move)  # the child node represents the position after this move
            self.children[move] = MCTSNode(child_board, parent=self, prior_prob=prior_prob)

    def evaluate(self, policy_value_net):
        input_board = board_to_input(self.board)
        policy_output, value_output = policy_value_net(np.expand_dims(input_board, axis=0))
        policy_probs = policy_output.numpy()[0]
        value = value_output.numpy()[0][0]

        # Mask out illegal moves and renormalise; fall back to a uniform distribution over
        # the legal moves if the network puts no mass on any of them.
        legal_moves_mask = get_legal_moves_mask(self.board)
        masked_policy_probs = policy_probs * legal_moves_mask
        if np.sum(masked_policy_probs) > 0:
            masked_policy_probs /= np.sum(masked_policy_probs)
        else:
            masked_policy_probs = legal_moves_mask / np.sum(legal_moves_mask)

        self.policy_prob = masked_policy_probs
        self.value = value
        return value, masked_policy_probs

    def backup(self, value):
        self.visits += 1
        self.value_sum += value
        self.value = self.value_sum / self.visits
        if self.parent:
            self.parent.backup(-value)  # flip the sign: parent and child are opposite sides

def run_mcts(root_node, policy_value_net, num_simulations):
    for _ in range(num_simulations):
        node = root_node
        search_path = [node]

        # Selection: descend until reaching an unexpanded node or a finished game
        while node.children and not node.board.is_game_over():
            node = node.select_child()
            search_path.append(node)

        leaf_node = search_path[-1]

        # Expansion and evaluation
        if not leaf_node.board.is_game_over():
            value, policy_probs = leaf_node.evaluate(policy_value_net)
            leaf_node.expand(policy_probs)
        else:
            value = get_game_result_value(leaf_node.board)

        # Backup along the search path
        leaf_node.backup(value)

    return choose_best_move_from_mcts(root_node)

def choose_best_move_from_mcts(root_node, temperature=0.0):
    if temperature == 0:
        # Greedy: pick the most-visited child
        best_move = max(root_node.children, key=lambda move: root_node.children[move].visits)
    else:
        # Sample proportionally to visit counts sharpened by 1/temperature
        visits = [root_node.children[move].visits for move in root_node.children]
        move_probs = np.array(visits) ** (1 / temperature)
        move_probs = move_probs / np.sum(move_probs)
        moves = list(root_node.children.keys())
        best_move = np.random.choice(moves, p=move_probs)
    return best_move

# --- 5. RL Engine Class ---
class RLEngine:
    def __init__(self, policy_value_net, num_simulations_per_move=100):
        self.policy_value_net = policy_value_net
        self.num_simulations_per_move = num_simulations_per_move

    def choose_move(self, board):
        root_node = MCTSNode(board)
        best_move = run_mcts(root_node, self.policy_value_net, self.num_simulations_per_move)
        return best_move

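# Hedged usage sketch (not part of the training run): choosing a single move with the engine.
# With an untrained network the priors are close to uniform, so the chosen move is essentially
# arbitrary; after training it reflects the learned policy and value.
#
#   net = PolicyValueNetwork(NUM_POSSIBLE_MOVES)
#   engine = RLEngine(net, num_simulations_per_move=25)
#   board = chess.Board()
#   move = engine.choose_move(board)              # returns a chess.Move
#   board.push(move)
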
# --- 6. Training Functions ---
def self_play_game(engine, model, num_simulations):
    game_history = []
    board = chess.Board()
    while not board.is_game_over():
        root_node = MCTSNode(board)
        run_mcts(root_node, model, num_simulations)

        policy_targets = create_policy_targets_from_mcts_visits(root_node)
        game_history.append((board.fen(), policy_targets, board.turn))

        best_move = choose_best_move_from_mcts(root_node, temperature=0.8)  # exploration temperature
        board.push(best_move)

    game_result = get_game_result_value(board)  # from White's perspective

    # Label every stored position with the final result from the perspective of the side to move there.
    for i in range(len(game_history)):
        fen, policy_target, turn = game_history[i]
        game_history[i] = (fen, policy_target, game_result if turn == chess.WHITE else -game_result)
    return game_history

def create_policy_targets_from_mcts_visits(root_node):
    policy_targets = np.zeros(NUM_POSSIBLE_MOVES, dtype=np.float32)
    for move, child_node in root_node.children.items():
        move_index = move_to_index(move)
        policy_targets[move_index] = child_node.visits
    policy_targets /= np.sum(policy_targets)
    return policy_targets

def train_step(model, board_inputs, policy_targets, value_targets, optimizer):
    with tf.GradientTape() as tape:
        policy_outputs, value_outputs = model(board_inputs)
        policy_loss = tf.keras.losses.CategoricalCrossentropy()(policy_targets, policy_outputs)
        value_loss = tf.keras.losses.MeanSquaredError()(value_targets, value_outputs)
        total_loss = policy_loss + value_loss
    gradients = tape.gradient(total_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return total_loss, policy_loss, value_loss

def train_network(model, game_histories, optimizer, epochs=10, batch_size=32):
    all_board_inputs = []
    all_policy_targets = []
    all_value_targets = []

    for game_history in game_histories:
        for fen, policy_target, game_result in game_history:
            board = chess.Board(fen)
            all_board_inputs.append(board_to_input(board))
            all_policy_targets.append(policy_target)
            all_value_targets.append(np.array([game_result], dtype=np.float32))

    all_board_inputs = np.array(all_board_inputs)
    all_policy_targets = np.array(all_policy_targets)
    all_value_targets = np.array(all_value_targets)

    dataset = tf.data.Dataset.from_tensor_slices((all_board_inputs, all_policy_targets, all_value_targets))
    dataset = dataset.shuffle(buffer_size=len(all_board_inputs)).batch(batch_size).prefetch(tf.data.AUTOTUNE)

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")
        for batch_inputs, batch_policy_targets, batch_value_targets in dataset:
            loss, p_loss, v_loss = train_step(model, batch_inputs, batch_policy_targets, batch_value_targets, optimizer)
            print(f"  Loss: {loss:.4f}, Policy Loss: {p_loss:.4f}, Value Loss: {v_loss:.4f}")

# --- 7. Main Training Execution in Colab ---
if __name__ == "__main__":
    # --- Check GPU availability in Colab ---
    if tf.config.list_physical_devices('GPU'):
        print("\n\nGPU is available and will be used for training.\n\n")
        gpu_device = '/GPU:0'  # Use GPU 0 if available
    else:
        print("\n\nGPU is not available. Training will use the CPU (may be slow).\n\n")
        gpu_device = '/CPU:0'

    with tf.device(gpu_device):  # Explicitly place operations on the GPU (if available)
        # Initialize the neural network, engine, and optimizer
        policy_value_net = PolicyValueNetwork(NUM_POSSIBLE_MOVES)
        engine = RLEngine(policy_value_net, num_simulations_per_move=100)
        optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

        # --- Training parameters ---
        num_self_play_games = 50  # Adjust for longer training
        epochs = 5  # Adjust for longer training

        # --- Run self-play and training ---
        game_histories = []
        start_time = time.time()

        # --- Model save directory in Colab ---
        MODEL_SAVE_DIR = "models_colab"
        os.makedirs(MODEL_SAVE_DIR, exist_ok=True)  # Create the directory if it doesn't exist

        for i in range(num_self_play_games):
            print(f"Self-play game {i+1}/{num_self_play_games}\n")
            game_history = self_play_game(engine, policy_value_net, num_simulations=50)  # Reduced simulations for faster games
            game_histories.append(game_history)

        train_network(policy_value_net, game_histories, optimizer, epochs=epochs)

        end_time = time.time()
        training_time = end_time - start_time
        print(f"\n\n ---- Training completed in {training_time:.2f} seconds. ---- \n")

        # --- Save the trained model weights (versioned .weights.h5 file) ---
        current_datetime = datetime.datetime.now()
        model_version_str = current_datetime.strftime("%Y-%m-%d-%H%M")  # Hour and minute added for uniqueness
        model_save_path = os.path.join(MODEL_SAVE_DIR, f"StockZero-{model_version_str}.weights.h5")
        policy_value_net.save_weights(model_save_path)  # Saves the weights only, not the architecture
        print(f"Trained model weights saved to '{model_save_path}'.")

        # --- (Optional) Zip the model directory and download it from Colab; remove this block if not needed ---
        import shutil
        zip_file_path = f"StockZero-{model_version_str}"
        shutil.make_archive(zip_file_path, 'zip', MODEL_SAVE_DIR)  # Create a zip archive of the save directory
        print(f"Model directory zipped to '{zip_file_path}.zip'. Download this file.")
        from google.colab import files
        files.download(f"{zip_file_path}.zip")  # Trigger the download in Colab

    print("\n\n ----- Training finished. ------- \n\n")
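
# Hedged usage sketch (assumption: run in a separate session after training). save_weights()
# stores only the weights, so rebuild the architecture, call it once to create the variables,
# then load the versioned file produced above (replace <version> with the actual timestamp):
#
#   net = PolicyValueNetwork(NUM_POSSIBLE_MOVES)
#   net(np.zeros((1, 8, 8, 12), dtype=np.float32))                      # build the variables
#   net.load_weights("models_colab/StockZero-<version>.weights.h5")
#   engine = RLEngine(net, num_simulations_per_move=100)
#   print(engine.choose_move(chess.Board()))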