Spaces:
Build error
Build error
| import os | |
| import sys | |
| import torch | |
| from torch import nn | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| import data | |
| from model import ChessModel | |
| def train(): | |
| device_string = "cuda" if torch.cuda.is_available() else "cpu" | |
| device = torch.device(device_string) | |
| model = ChessModel(256).to(torch.float32).to(device) | |
| opt = torch.optim.Adam(model.parameters()) | |
| reconstruction_loss_fn = nn.CrossEntropyLoss().to(torch.float32).to(device) | |
| popularity_loss_fn = nn.L1Loss().to(torch.float32).to(device) | |
| evaluation_loss_fn = nn.L1Loss().to(torch.float32).to(device) | |
| data_loader = DataLoader(data.LichessPuzzleDataset(cap_data=65536), batch_size=64, num_workers=1) # 1 to avoid threading madness. | |
| num_epochs = 100 | |
| for epoch in range(num_epochs): | |
| model.train() | |
| total_reconstruction_loss = 0.0 | |
| total_popularity_loss = 0.0 | |
| total_evaluation_loss = 0.0 | |
| total_batch_loss = 0.0 | |
| num_batches = 0 | |
| for batch_idx, (board_vec, popularity, evaluation) in tqdm(enumerate(data_loader)): | |
| board_vec = board_vec.to(torch.float32).to(device) # [batch_size x 903] | |
| popularity = popularity.to(torch.float32).to(device).unsqueeze(1) # enforce [batch_size, 1] | |
| evaluation = evaluation.to(torch.float32).to(device).unsqueeze(1) | |
| _embedding, predicted_popularity, predicted_evaluation, predicted_board_vec = model(board_vec) | |
| reconstruction_loss = reconstruction_loss_fn(predicted_board_vec, board_vec) | |
| popularity_loss = popularity_loss_fn(predicted_popularity, popularity) | |
| evaluation_loss = evaluation_loss_fn(predicted_evaluation, evaluation) | |
| total_loss = reconstruction_loss + popularity_loss + evaluation_loss | |
| opt.zero_grad() | |
| total_loss.backward() | |
| opt.step() | |
| total_reconstruction_loss += reconstruction_loss.cpu().item() | |
| total_popularity_loss += popularity_loss.cpu().item() | |
| total_evaluation_loss += evaluation_loss.cpu().item() | |
| total_batch_loss += total_loss.cpu().item() | |
| num_batches += 1 | |
| print(f"Average reconstruction loss: {total_reconstruction_loss/num_batches}") | |
| print(f"Average popularity loss: {total_popularity_loss/num_batches}") | |
| print(f"Average evaluation loss: {total_evaluation_loss/num_batches}") | |
| print(f"Average batch loss: {total_batch_loss/num_batches}") | |
| torch.save(model, f"checkpoints/epoch_{epoch}.pth") | |
| def infer(fen): | |
| pass | |
| def test(): | |
| pass | |
| if __name__ == "__main__": | |
| if len(sys.argv) < 2: | |
| print("Usage: python {sys.argv[0]} --train|infer") | |
| elif sys.argv[1] == "--train": | |
| train() | |
| elif sys.argv[2] == "--infer": | |
| infer(sys.argv[3]) | |