nirajandhakal committed · Commit 334352f · verified · 1 Parent(s): ec6db90

Update README.md

Files changed (1): README.md (+205 -3)

---
language: en
license: mit
tags:
- chess
- reinforcement-learning
- deep-learning
- tensorflow
- pytorch
- onnx
- tflite
- self-play
- mcts
---

# StockZero: A Self-Play Reinforcement Learning Chess Engine

This model card describes **StockZero**, a self-play reinforcement learning chess engine trained using TensorFlow/Keras. It combines a policy-value neural network with Monte Carlo Tree Search (MCTS) for decision making. StockZero serves as an educational example of applying deep RL to the game of chess. This model card also includes information about the model's converted formats and their usage.

## Model Details

### Model Description

StockZero learns to play chess by playing against itself. The core component is a neural network that takes a chess board state as input and outputs:

1. **Policy**: A probability distribution over all legal moves, indicating which move the model thinks is best.
2. **Value**: An estimate of the win/loss probability from the current player's perspective.

The model is trained using self-play data generated through MCTS, which guides the engine to explore promising game states.

### Input

The model takes a chess board as input, represented as an 8x8x12 NumPy array. Each of the 12 channels encodes one piece type (pawn, knight, bishop, rook, queen, king) for one color (white or black), with binary values marking the squares occupied by that piece.

### Output

The model outputs two vectors:

1. **Policy**: A probability distribution over `NUM_POSSIBLE_MOVES = 4672` move indices, representing the probability of playing each move, obtained with a `softmax` activation.
2. **Value**: A single scalar indicating the win/loss probability from the current player's perspective, ranging from -1 (loss) to 1 (win), obtained with a `tanh` activation.

### Model Architecture

The neural network architecture consists of:

* One convolutional layer: `Conv2D(32, 3, activation='relu', padding='same')`
* A flatten layer: `Flatten()`
* Two dense heads:
  * `Dense(NUM_POSSIBLE_MOVES, activation='softmax', name='policy_head')` for move probabilities
  * `Dense(1, activation='tanh', name='value_head')` for win/loss estimation

### Training Data

The model was trained on data generated through self-play: the engine plays chess games against itself, and the resulting games are used to train the network iteratively. This process is similar to the AlphaZero approach.

### Training Procedure

1. **Self-Play**: The engine plays against itself, using MCTS to select moves and generating game trajectories.
2. **Data Collection**: During self-play, each board state and its MCTS visit counts are recorded as the target policy, and the final game result is saved as the target value.
3. **Training**: The model learns from the self-play data using a combination of categorical cross-entropy for the policy and mean squared error for the value.

The optimizer used during training is **Adam** with a learning rate of 0.001.
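
As a rough sketch of how this loss combination can be wired up (not the exact `training_code.py` implementation), the step below assumes a built `PolicyValueNetwork` instance (defined in the Inference section below) and self-play batches of board planes, normalized MCTS visit counts, and final game results:

```python
import tensorflow as tf

# Hedged sketch of one training step; shapes are assumptions:
#   states:          (N, 8, 8, 12)            board planes
#   target_policies: (N, NUM_POSSIBLE_MOVES)  normalized MCTS visit counts
#   target_values:   (N, 1)                   final game results in [-1, 1]
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
policy_loss_fn = tf.keras.losses.CategoricalCrossentropy()
value_loss_fn = tf.keras.losses.MeanSquaredError()

def train_step(model, states, target_policies, target_values):
    with tf.GradientTape() as tape:
        pred_policies, pred_values = model(states)
        # Combined loss: cross-entropy for the policy head, MSE for the value head.
        loss = (policy_loss_fn(target_policies, pred_policies)
                + value_loss_fn(target_values, pred_values))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
```

Equal weighting of the two loss terms is the simplest choice; in practice the policy and value losses could be weighted differently.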

### Training Parameters

* `num_self_play_games = 50`
* `epochs = 5`
* `num_simulations_per_move = 100`

### Model Versions

This model has been converted into several formats for flexible deployment:

* **TensorFlow SavedModel**: A directory containing the model architecture and weights, for native TensorFlow usage.
* **Keras Model (.keras)**: A full Keras model (architecture and weights), suitable for Keras environments.
* **Keras Weights (.h5)**: Weights only, which can be loaded into an existing `PolicyValueNetwork` in Keras/TensorFlow.
* **PyTorch Model (.pth)**: A full PyTorch equivalent (architecture and weights).
* **PyTorch Weights (.pth)**: Weights that can be loaded into a `PyTorchPolicyValueNetwork`.
* **ONNX (.onnx)**: A standard format for interoperability with various machine learning frameworks.
* **TensorFlow Lite (.tflite)**: An efficient model format for mobile and embedded devices.
* **Raw Binary (.bin)**: A raw byte representation of all model weights, for use in custom implementations.
* **NumPy Array (.npz)**: Model weights saved as individual NumPy arrays, which can be loaded easily in many environments.

Model files are versioned by appending the training timestamp to the filename, so each run produces unique names.
For example: `StockZero-2025-03-24-1727.weights.h5` or `converted_models-202503241727.zip`.
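
As an illustration of using the converted formats, the sketch below runs the ONNX model with `onnxruntime` and the TFLite model with the TensorFlow Lite interpreter. The file names are placeholders, and the output ordering and tensor layout depend on how the conversion script exported the model:

```python
import numpy as np
import onnxruntime as ort  # pip install onnxruntime
import tensorflow as tf

# Hedged sketch: replace the placeholder paths with the actual timestamped
# files from your converted_models download.
input_data = np.random.rand(1, 8, 8, 12).astype(np.float32)  # or board_to_input(board) with a batch dim

# --- ONNX ---
session = ort.InferenceSession("converted_models/stockzero.onnx")
input_name = session.get_inputs()[0].name
onnx_outputs = session.run(None, {input_name: input_data})  # [policy, value] order depends on the export

# --- TensorFlow Lite ---
interpreter = tf.lite.Interpreter(model_path="converted_models/stockzero.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
tflite_outputs = [interpreter.get_tensor(d["index"]) for d in output_details]
```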

### Intended Use

The model is intended for research, experimentation, and educational purposes. Potential applications include:

* Studying reinforcement learning algorithms applied to complex games.
* Developing chess-playing AI.
* Serving as a base model for fine-tuning and further research.
* Deployment as a lightweight engine on constrained hardware (TFLite).
* Use in non-TensorFlow environments (PyTorch, ONNX).

### Limitations

* The model is not intended to compete with top-level chess engines.
* The training data is limited to a small number of self-play games (50), so the playing strength of the engine is limited.
* The model was trained on a single GPU; substantially longer training may require more runtime or multi-GPU support.

## How to Use

### Training

1. Upload `training_code.py` to Google Colab.
2. Run the script to train the model on Google Colab.
3. Download the zip file of the trained weights and model, which is provided automatically after training completes.

### Model Conversion

1. Place `conversion_script.py` in Google Colab and make sure the saved weights are in the expected location.
2. Run the script to create model files in the different formats inside a `converted_models` folder.
3. Download the zip file containing all converted models via the automatic Colab download triggered at the end of the script.

### Inference

To use the model for inference, load the model weights into an instance of `PolicyValueNetwork` (or its PyTorch equivalent) and use the `board_to_input` and `get_legal_moves_mask` functions to prepare the input. The following code shows how to make predictions:

```python
import chess
import numpy as np
import tensorflow as tf

class PolicyValueNetwork(tf.keras.Model):
    def __init__(self, num_moves):
        super(PolicyValueNetwork, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(32, 3, activation='relu', padding='same')
        self.flatten = tf.keras.layers.Flatten()
        self.dense_policy = tf.keras.layers.Dense(num_moves, activation='softmax', name='policy_head')
        self.dense_value = tf.keras.layers.Dense(1, activation='tanh', name='value_head')

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.flatten(x)
        policy = self.dense_policy(x)
        value = self.dense_value(x)
        return policy, value

NUM_POSSIBLE_MOVES = 4672
NUM_INPUT_PLANES = 12

def board_to_input(board):
    """Encode a chess.Board as 8x8x12 binary planes (6 piece types x 2 colors)."""
    piece_types = [chess.PAWN, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN, chess.KING]
    input_planes = np.zeros((8, 8, NUM_INPUT_PLANES), dtype=np.float32)
    for piece_type_index, piece_type in enumerate(piece_types):
        for square in chess.SQUARES:
            piece = board.piece_at(square)
            if piece is not None and piece.piece_type == piece_type:
                plane_index = piece_type_index if piece.color == chess.WHITE else piece_type_index + 6
                row, col = chess.square_rank(square), chess.square_file(square)
                input_planes[row, col, plane_index] = 1.0
    return input_planes

def get_legal_moves_mask(board):
    """Return a length-NUM_POSSIBLE_MOVES mask with 1.0 at each legal move index."""
    legal_moves = list(board.legal_moves)
    move_indices = [move_to_index(move) for move in legal_moves]
    mask = np.zeros(NUM_POSSIBLE_MOVES, dtype=np.float32)
    mask[move_indices] = 1.0
    return mask

def move_to_index(move):
    """Map a chess.Move to an index: from*64 + to for normal moves, promotion bands above 4096."""
    if move.promotion is None:
        index = move.from_square * 64 + move.to_square
    elif move.promotion == chess.KNIGHT:
        index = 4096 + move.to_square
    elif move.promotion == chess.BISHOP:
        index = 4096 + 64 + move.to_square
    elif move.promotion == chess.ROOK:
        index = 4096 + 64 * 2 + move.to_square
    elif move.promotion == chess.QUEEN:
        index = 4096 + 64 * 3 + move.to_square
    else:
        raise ValueError(f"Unknown promotion piece type: {move.promotion}")
    return index

# Load model weights
policy_value_net = PolicyValueNetwork(NUM_POSSIBLE_MOVES)
# Dummy forward pass to build the network before loading weights
dummy_input = tf.random.normal((1, 8, 8, NUM_INPUT_PLANES))
policy, value = policy_value_net(dummy_input)

# Replace 'path/to/your/model.weights.h5' with the actual path to your .h5 weights
model_path = "path/to/your/model.weights.h5"
policy_value_net.load_weights(model_path)

# Example usage
board = chess.Board()
input_data = board_to_input(board)
legal_moves_mask = get_legal_moves_mask(board)
input_data = np.expand_dims(input_data, axis=0)  # Add batch dimension

policy_output, value_output = policy_value_net(input_data)
policy_output = policy_output.numpy()
value_output = value_output.numpy()
masked_policy_probs = policy_output[0] * legal_moves_mask  # Apply the legal-move mask (zero out illegal moves)

# Normalize the masked policy; leave it as all zeros if no probability mass remains.
if np.sum(masked_policy_probs) > 0:
    masked_policy_probs /= np.sum(masked_policy_probs)

print("Policy Output:", masked_policy_probs)
print("Value Output:", value_output)
```