""" Training script for baseline NanoGPT model on enwik8 dataset. Ensures proper bpc calculation and comparable evaluation with DTAT. """ import os import time import math import numpy as np import torch import torch.nn.functional as F from torch.nn.parallel import DistributedDataParallel as DDP from torch.distributed import init_process_group, destroy_process_group from contextlib import nullcontext import wandb from tqdm import tqdm from model_baseline import BaselineTransformer from config.baseline_config import get_config def get_batch(data, block_size, batch_size, device): """Generate a small batch of data of inputs x and targets y.""" ix = torch.randint(len(data) - block_size, (batch_size,)) x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) x, y = x.to(device), y.to(device) return x, y def estimate_loss(model, data, config): """Estimate loss on data split, ensuring proper bpc calculation.""" model.eval() losses = torch.zeros(config.eval_iters) for k in range(config.eval_iters): X, Y = get_batch(data, config.block_size, config.batch_size, config.device) with torch.no_grad(): logits, loss = model(X, Y) losses[k] = loss.item() # Loss is already in BPC out = losses.mean() model.train() return out def get_lr(it, config): """Get learning rate based on iteration.""" if it < config.warmup_iters: return config.learning_rate * it / config.warmup_iters if it > config.lr_decay_iters: return config.min_lr decay_ratio = (it - config.warmup_iters) / (config.lr_decay_iters - config.warmup_iters) coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) return config.min_lr + coeff * (config.learning_rate - config.min_lr) def main(): # Initialize config config = get_config() # Initialize wandb wandb.init(project='enwik8-baseline', config=vars(config)) # Load dataset data = np.memmap('data/train.bin', dtype=np.uint8, mode='r') val_data = np.memmap('data/val.bin', dtype=np.uint8, mode='r') # Initialize model model = BaselineTransformer(config) model.to(config.device) # Initialize optimizer optimizer = torch.optim.AdamW( model.parameters(), lr=config.learning_rate, betas=(config.beta1, config.beta2), weight_decay=config.weight_decay ) if config.compile: print("Compiling model...") model = torch.compile(model) # Enable mixed precision training scaler = torch.cuda.amp.GradScaler(enabled=config.mixed_precision) # Enable cuDNN benchmarking torch.backends.cudnn.benchmark = True # Calculate total steps and epochs total_steps = config.max_iters batch_size = config.batch_size block_size = config.block_size total_epochs = (total_steps * batch_size * block_size) // len(data) print(f"Training baseline model for {total_epochs} epochs ({total_steps} iterations)") # Create progress bar pbar = tqdm(range(config.max_iters), desc=f"Training (0/{total_epochs} epochs)") best_val_loss = float('inf') t0 = time.time() for iter_num in pbar: # Update learning rate lr = get_lr(iter_num, config) for param_group in optimizer.param_groups: param_group['lr'] = lr # Sample batch X, Y = get_batch(data, config.block_size, config.batch_size, config.device) # Forward pass with mixed precision with torch.cuda.amp.autocast(enabled=config.mixed_precision): logits, loss = model(X, Y) # Backward pass with gradient scaling optimizer.zero_grad(set_to_none=True) scaler.scale(loss).backward() scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) scaler.step(optimizer) 
        scaler.update()

        # Logging
        if iter_num % config.log_interval == 0:
            # Calculate current epoch
            current_tokens = (iter_num + 1) * batch_size * block_size
            current_epoch = current_tokens / len(data)

            # Throughput over the interval since the last log
            elapsed = time.time() - t0
            interval_iters = config.log_interval if iter_num > 0 else 1
            tokens_per_sec = (batch_size * block_size * interval_iters) / elapsed

            # Update progress bar
            pbar.set_description(
                f"Training ({current_epoch:.1f}/{total_epochs} epochs) | "
                f"loss: {loss.item():.4f} | "  # Already in bpc
                f"lr: {lr:.1e} | "
                f"tokens/sec: {tokens_per_sec:.1f}"
            )

            # Log to wandb
            wandb.log({
                "iter": iter_num,
                "loss": loss.item(),
                "bpc": loss.item(),  # Already in bpc
                "lr": lr,
                "epoch": current_epoch,
                "tokens_per_sec": tokens_per_sec,
            })
            t0 = time.time()

        # Evaluation
        if iter_num > 0 and iter_num % config.eval_interval == 0:
            val_loss = estimate_loss(model, val_data, config)
            wandb.log({
                "val_loss": val_loss,
                "val_bpc": val_loss,  # Already in bpc
                "epoch": current_epoch,
            })

            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                print(f"Saving best model with val_bpc: {val_loss:.4f}")
                torch.save(model.state_dict(), 'models/baseline_best.pt')

    # Final evaluation
    model.eval()
    final_val_loss = estimate_loss(model, val_data, config)
    print(f"Final validation BPC: {final_val_loss:.4f}")

    # Save final model
    torch.save(model.state_dict(), 'models/baseline_final.pt')

    wandb.finish()


if __name__ == '__main__':
    main()