nirajandhakal committed
Commit 450608e · verified · 1 Parent(s): 3bc4389

Create training-script-v2.py

Files changed (1)
  1. training-script-v2.py +348 -0
training-script-v2.py ADDED
@@ -0,0 +1,348 @@
import chess
import chess.engine
import numpy as np
import tensorflow as tf
import time
import os
import datetime

# --- 1. Neural Network (Policy and Value Network) ---
class PolicyValueNetwork(tf.keras.Model):
    def __init__(self, num_moves):
        super(PolicyValueNetwork, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(32, 3, activation='relu', padding='same')  # Removed input_shape
        self.flatten = tf.keras.layers.Flatten()
        self.dense_policy = tf.keras.layers.Dense(num_moves, activation='softmax', name='policy_head')
        self.dense_value = tf.keras.layers.Dense(1, activation='tanh', name='value_head')

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.flatten(x)
        policy = self.dense_policy(x)
        value = self.dense_value(x)
        return policy, value

# --- 2. Board Representation and Preprocessing ---
def board_to_input(board):
    piece_types = [chess.PAWN, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN, chess.KING]
    input_planes = np.zeros((8, 8, 12), dtype=np.float32)

    for piece_type_index, piece_type in enumerate(piece_types):
        for square in chess.SQUARES:
            piece = board.piece_at(square)
            if piece is not None and piece.piece_type == piece_type:
                plane_index = piece_type_index if piece.color == chess.WHITE else piece_type_index + 6
                row, col = chess.square_rank(square), chess.square_file(square)
                input_planes[row, col, plane_index] = 1.0
    return input_planes

def get_legal_moves_mask(board):
    legal_moves = list(board.legal_moves)
    move_indices = [move_to_index(move) for move in legal_moves]

    # --- Defensive check: filter out-of-bounds indices ---
    valid_move_indices = []
    out_of_bounds_indices = []
    for index in move_indices:
        if 0 <= index < NUM_POSSIBLE_MOVES:
            valid_move_indices.append(index)
        else:
            out_of_bounds_indices.append(index)

    mask = np.zeros(NUM_POSSIBLE_MOVES, dtype=np.float32)
    mask[valid_move_indices] = 1.0
    return mask

# --- 3. Move Encoding/Decoding (Correct and Deterministic Implementation) ---
NUM_POSSIBLE_MOVES = 4672  # Fixed-size policy output; this encoding uses indices 0-4351

def move_to_index(move):
    """Standard, deterministic move-to-index conversion (UCI-like encoding)."""
    # Non-promotion moves (most common): encode source and target squares
    if move.promotion is None:
        index = move.from_square * 64 + move.to_square
    # Promotion moves - use offsets to separate them from non-promotion indices
    elif move.promotion == chess.KNIGHT:
        index = 4096 + move.to_square           # Knight promotions start after non-promotion moves
    elif move.promotion == chess.BISHOP:
        index = 4096 + 64 + move.to_square      # Bishop promotions after knights
    elif move.promotion == chess.ROOK:
        index = 4096 + 64 * 2 + move.to_square  # Rook promotions after bishops
    elif move.promotion == chess.QUEEN:
        index = 4096 + 64 * 3 + move.to_square  # Queen promotions after rooks
    else:
        raise ValueError(f"Unknown promotion piece type: {move.promotion}")

    return index

def index_to_move(index, board):
    """Standard, deterministic index-to-move conversion (index to chess.Move)."""

    if 0 <= index < 4096:  # Non-promotion moves
        from_square = index // 64
        to_square = index % 64
        promotion = None

    elif 4096 <= index < 4096 + 64:  # Knight promotions
        from_square_rank = chess.square_rank(chess.A8) - 1  # Rank 8 for White pawns, rank 1 for Black pawns; -1 for index conversion
        from_square = chess.square(chess.square_file(chess.A1), from_square_rank)  # Assume promotion from any file on the promotion rank. Refine as needed.
        to_square = index - 4096
        promotion = chess.KNIGHT

    elif 4096 + 64 <= index < 4096 + 64 * 2:  # Bishop promotions
        from_square_rank = chess.square_rank(chess.A8) - 1
        from_square = chess.square(chess.square_file(chess.A1), from_square_rank)
        to_square = index - (4096 + 64)
        promotion = chess.BISHOP

    elif 4096 + 64 * 2 <= index < 4096 + 64 * 3:  # Rook promotions
        from_square_rank = chess.square_rank(chess.A8) - 1
        from_square = chess.square(chess.square_file(chess.A1), from_square_rank)
        to_square = index - (4096 + 64 * 2)
        promotion = chess.ROOK

    elif 4096 + 64 * 3 <= index < NUM_POSSIBLE_MOVES:  # Queen promotions
        from_square_rank = chess.square_rank(chess.A8) - 1
        from_square = chess.square(chess.square_file(chess.A1), from_square_rank)
        to_square = index - (4096 + 64 * 3)
        promotion = chess.QUEEN

    else:  # Invalid index
        return None

    move = chess.Move(from_square, to_square, promotion=promotion)
    if move in board.legal_moves:
        return move
    return None  # Move is not legal

def get_game_result_value(board):
    if board.is_checkmate():
        # The side to move is checkmated, so the other side won: +1 for a White win, -1 for a Black win
        return 1 if board.turn == chess.BLACK else -1
    elif board.is_stalemate() or board.is_insufficient_material() or board.is_seventyfive_moves() or board.is_fivefold_repetition() or board.is_variant_draw():
        return 0
    else:
        return 0

# --- 4. Monte Carlo Tree Search (MCTS) ---
class MCTSNode:
    def __init__(self, board, parent=None, prior_prob=0):
        self.board = board.copy()
        self.parent = parent
        self.children = {}
        self.visits = 0
        self.value_sum = 0
        self.prior_prob = prior_prob
        self.policy_prob = 0
        self.value = 0

    def select_child(self, exploration_constant=1.4):
        best_child = None
        best_ucb = -float('inf')
        for move, child in self.children.items():
            ucb = child.value + exploration_constant * child.prior_prob * np.sqrt(self.visits) / (1 + child.visits)
            if ucb > best_ucb:
                best_ucb = ucb
                best_child = child
        return best_child

    def expand(self, policy_probs):
        legal_moves = list(self.board.legal_moves)
        for move in legal_moves:
            move_index = move_to_index(move)
            prior_prob = policy_probs[move_index]
            # Child node must hold the position *after* the move has been played
            child_board = self.board.copy()
            child_board.push(move)
            self.children[move] = MCTSNode(child_board, parent=self, prior_prob=prior_prob)

    def evaluate(self, policy_value_net):
        input_board = board_to_input(self.board)
        policy_output, value_output = policy_value_net(np.expand_dims(input_board, axis=0))
        policy_probs = policy_output.numpy()[0]
        value = value_output.numpy()[0][0]

        legal_moves_mask = get_legal_moves_mask(self.board)
        masked_policy_probs = policy_probs * legal_moves_mask
        if np.sum(masked_policy_probs) > 0:
            masked_policy_probs /= np.sum(masked_policy_probs)
        else:
            masked_policy_probs = legal_moves_mask / np.sum(legal_moves_mask)

        self.policy_prob = masked_policy_probs
        self.value = value
        return value, masked_policy_probs

    def backup(self, value):
        self.visits += 1
        self.value_sum += value
        self.value = self.value_sum / self.visits
        if self.parent:
            self.parent.backup(-value)

def run_mcts(root_node, policy_value_net, num_simulations):
    for _ in range(num_simulations):
        node = root_node
        search_path = [node]

        while node.children and not node.board.is_game_over():
            node = node.select_child()
            search_path.append(node)

        leaf_node = search_path[-1]

        if not leaf_node.board.is_game_over():
            value, policy_probs = leaf_node.evaluate(policy_value_net)
            leaf_node.expand(policy_probs)
        else:
            value = get_game_result_value(leaf_node.board)

        leaf_node.backup(value)

    return choose_best_move_from_mcts(root_node)

def choose_best_move_from_mcts(root_node, temperature=0.0):
    if temperature == 0:
        best_move = max(root_node.children, key=lambda move: root_node.children[move].visits)
    else:
        visits = [root_node.children[move].visits for move in root_node.children]
        move_probs = np.array(visits) ** (1 / temperature)
        move_probs = move_probs / np.sum(move_probs)
        moves = list(root_node.children.keys())
        best_move = np.random.choice(moves, p=move_probs)
    return best_move

# --- 5. RL Engine Class ---
class RLEngine:
    def __init__(self, policy_value_net, num_simulations_per_move=100):
        self.policy_value_net = policy_value_net
        self.num_simulations_per_move = num_simulations_per_move

    def choose_move(self, board):
        root_node = MCTSNode(board)
        best_move = run_mcts(root_node, self.policy_value_net, self.num_simulations_per_move)
        return best_move

# --- 6. Training Functions ---
def self_play_game(engine, model, num_simulations):
    game_history = []
    board = chess.Board()
    while not board.is_game_over():
        root_node = MCTSNode(board)
        run_mcts(root_node, model, num_simulations)

        policy_targets = create_policy_targets_from_mcts_visits(root_node)
        game_history.append((board.fen(), policy_targets))

        best_move = choose_best_move_from_mcts(root_node, temperature=0.8)  # Exploration temperature
        board.push(best_move)

    game_result = get_game_result_value(board)

    # Value target is the final result from the perspective of the side to move in each stored position
    for i in range(len(game_history)):
        fen, policy_target = game_history[i]
        position_turn = chess.Board(fen).turn
        game_history[i] = (fen, policy_target, game_result if position_turn == chess.WHITE else -game_result)
    return game_history

def create_policy_targets_from_mcts_visits(root_node):
    policy_targets = np.zeros(NUM_POSSIBLE_MOVES, dtype=np.float32)
    for move, child_node in root_node.children.items():
        move_index = move_to_index(move)
        policy_targets[move_index] = child_node.visits
    policy_targets /= np.sum(policy_targets)
    return policy_targets

def train_step(model, board_inputs, policy_targets, value_targets, optimizer):
    with tf.GradientTape() as tape:
        policy_outputs, value_outputs = model(board_inputs)
        policy_loss = tf.keras.losses.CategoricalCrossentropy()(policy_targets, policy_outputs)
        value_loss = tf.keras.losses.MeanSquaredError()(value_targets, value_outputs)
        total_loss = policy_loss + value_loss
    gradients = tape.gradient(total_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return total_loss, policy_loss, value_loss

def train_network(model, game_histories, optimizer, epochs=10, batch_size=32):
    all_board_inputs = []
    all_policy_targets = []
    all_value_targets = []

    for game_history in game_histories:
        for fen, policy_target, game_result in game_history:
            board = chess.Board(fen)
            all_board_inputs.append(board_to_input(board))
            all_policy_targets.append(policy_target)
            all_value_targets.append(np.array([game_result]))

    all_board_inputs = np.array(all_board_inputs)
    all_policy_targets = np.array(all_policy_targets)
    all_value_targets = np.array(all_value_targets)

    dataset = tf.data.Dataset.from_tensor_slices((all_board_inputs, all_policy_targets, all_value_targets))
    dataset = dataset.shuffle(buffer_size=len(all_board_inputs)).batch(batch_size).prefetch(tf.data.AUTOTUNE)

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")
        for batch_inputs, batch_policy_targets, batch_value_targets in dataset:
            loss, p_loss, v_loss = train_step(model, batch_inputs, batch_policy_targets, batch_value_targets, optimizer)
            print(f"  Loss: {loss:.4f}, Policy Loss: {p_loss:.4f}, Value Loss: {v_loss:.4f}")

# --- 7. Main Training Execution in Colab ---
if __name__ == "__main__":
    # --- Check GPU Availability in Colab ---
    if tf.config.list_physical_devices('GPU'):
        print("\n\nGPU is available and will be used for training.\n\n")
        gpu_device = '/GPU:0'  # Use GPU 0 if available
    else:
        print("\n\nGPU is not available. Training will use CPU (may be slow).\n\n")
        gpu_device = '/CPU:0'

    with tf.device(gpu_device):  # Explicitly place operations on the GPU (if available)
        # Initialize the neural network, engine, and optimizer
        policy_value_net = PolicyValueNetwork(NUM_POSSIBLE_MOVES)
        engine = RLEngine(policy_value_net, num_simulations_per_move=100)
        optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

        # --- Training Parameters ---
        num_self_play_games = 50  # Adjust for longer training
        epochs = 5  # Adjust for longer training

        # --- Run Self-Play and Training ---
        game_histories = []
        start_time = time.time()

        # --- Model Save Directory in Colab ---
        MODEL_SAVE_DIR = "models_colab"  # Directory to save the model in Colab
        os.makedirs(MODEL_SAVE_DIR, exist_ok=True)  # Create the directory if it doesn't exist

        for i in range(num_self_play_games):
            print(f"Self-play game {i+1}/{num_self_play_games} \n")
            game_history = self_play_game(engine, policy_value_net, num_simulations=50)  # Reduced simulations for faster games
            game_histories.append(game_history)

        train_network(policy_value_net, game_histories, optimizer, epochs=epochs)

        end_time = time.time()
        training_time = end_time - start_time
        print(f"\n\n ---- Training completed in {training_time:.2f} seconds. ---- \n")

        # --- Save the trained model weights (HDF5 weights file) ---
        current_datetime = datetime.datetime.now()
        model_version_str = current_datetime.strftime("%Y-%m-%d-%H%M")  # Hour and minute added for uniqueness
        model_save_path = os.path.join(MODEL_SAVE_DIR, f"StockZero-{model_version_str}.weights.h5")  # Versioned filename
        policy_value_net.save_weights(model_save_path)  # Saves the model weights only (not the architecture)
        print(f"Trained model weights saved to '{model_save_path}' in the '{MODEL_SAVE_DIR}' directory in Colab.")

        # --- Zip and download the saved model (for use outside Colab) ---
        import shutil
        zip_file_path = f"StockZero-{model_version_str}"
        shutil.make_archive(zip_file_path, 'zip', MODEL_SAVE_DIR)  # Create zip archive
        print(f"Model directory zipped to '{zip_file_path}.zip'. Download this file.")
        from google.colab import files
        files.download(f"{zip_file_path}.zip")  # Trigger download in Colab

        print("\n\n ----- Training finished. ------- \n\n")