"""
Training script for Dynamic Token-Aware Transformer (DTAT) on enwik8 dataset.
Based on NanoGPT's training structure with modifications for token importance awareness.
"""
import os
import time
import math
import pickle
from contextlib import nullcontext
import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import matplotlib.pyplot as plt
import wandb
from tqdm import tqdm
from datetime import datetime
from model_dtat import DTATTransformer
from config.dtat_config import get_config
# -----------------------------------------------------------------------------
# I/O
def get_batch(data, block_size, batch_size, device):
    """Generate a small batch of data of inputs x and targets y."""
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y
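# Note: `data` is a read-only uint8 memmap, so each slice is copied and widened
# to int64 before stacking; embedding lookups expect integer indices.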
def compute_freq_table(data, vocab_size=256):
    """Compute the relative byte-frequency table for the dataset."""
    freq = np.bincount(data, minlength=vocab_size)
    return freq / len(data)
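# For byte-level enwik8 the table has 256 entries that sum to 1; e.g.
# freq_table[ord(' ')] is the fraction of space characters in the corpus.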
def visualize_importance(tokens, importance_scores, iter_num):
    """Plot per-position token importance scores and log the figure to wandb."""
    plt.figure(figsize=(15, 5))
    # Detach and move to CPU before converting to numpy
    scores = importance_scores.detach().squeeze().cpu().numpy()
    plt.bar(range(len(tokens)), scores)
    plt.title(f'Token Importance Scores (Iteration {iter_num})')
    plt.xlabel('Token Position')
    plt.ylabel('Importance Score')
    # Add token labels if the sequence is not too long
    if len(tokens) <= 50:
        plt.xticks(range(len(tokens)), tokens, rotation=45)
    # Save plot to wandb
    wandb.log({
        'importance_scores': wandb.Image(plt),
        'iter': iter_num
    })
    plt.close()
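# The `tokens` passed in are raw byte values (0-255), so the tick labels are
# integers, not decoded characters, unless a caller maps them with chr() first.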
# -----------------------------------------------------------------------------
# Training
def estimate_loss(model, data, config):
    """Estimate the mean loss over config.eval_iters random batches."""
    model.eval()
    losses = torch.zeros(config.eval_iters)
    for k in range(config.eval_iters):
        X, Y = get_batch(data, config.block_size, config.batch_size, config.device)
        with torch.no_grad():
            logits, loss, _ = model(X, Y)
        losses[k] = loss.item()
    model.train()
    return losses.mean().item()
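# Note: per the logging comments in main(), the model reports loss directly in
# bits per character, so the value returned here is logged as val_bpc unchanged.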
def get_lr(it, config):
    """Learning rate schedule: linear warmup followed by cosine decay."""
    # Linear warmup
    if it < config.warmup_iters:
        return config.learning_rate * it / config.warmup_iters
    # Cosine decay down to a minimum learning rate
    if config.decay_lr:
        decay_ratio = (it - config.warmup_iters) / (config.lr_decay_iters - config.warmup_iters)
        decay_ratio = min(decay_ratio, 1.0)  # cap at 1.0
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
        return config.min_lr + coeff * (config.learning_rate - config.min_lr)
    return config.learning_rate
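# Worked example with the values set in main() (warmup_iters=2000, lr=6e-4):
# get_lr(1000) = 6e-4 * 1000/2000 = 3e-4. Past warmup, coeff falls smoothly
# from 1.0 to 0.0, so the rate decays from learning_rate to min_lr by
# lr_decay_iters and stays at min_lr afterwards.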
def main():
    # Get config first; the DDP setup below reads config.batch_size
    config = get_config()
    # Initialize distributed training if launched via torchrun
    ddp = int(os.environ.get('RANK', -1)) != -1
    if ddp:
        init_process_group(backend='nccl')
        ddp_rank = int(os.environ['RANK'])
        ddp_local_rank = int(os.environ['LOCAL_RANK'])
        ddp_world_size = int(os.environ['WORLD_SIZE'])
        device = f'cuda:{ddp_local_rank}'
        torch.cuda.set_device(device)
        master_process = ddp_rank == 0
        seed_offset = ddp_rank
        # Split the global batch size evenly across ranks
        assert config.batch_size % ddp_world_size == 0
        config.batch_size = config.batch_size // ddp_world_size
    else:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        master_process = True
        seed_offset = 0
    # Set seed for reproducibility (offset per rank so batches differ)
    torch.manual_seed(1337 + seed_offset)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    device_type = 'cuda' if 'cuda' in device else 'cpu'
    config.device = device
    # Hyperparameter overrides (set before wandb records the config)
    config.warmup_iters = 2000  # increased warmup iterations
    config.learning_rate = 6e-4  # confirmed learning rate
    # Initialize wandb on the master process only
    if master_process:
        wandb.init(project="enwik8-dtat")
        wandb.config.update(config.__dict__)
    # Data loading
    print("Loading data...")
    data_dir = os.path.join('data')
    train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint8, mode='r')
    val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint8, mode='r')
    # Compute frequency table for the training data (computed here but not
    # consumed further in this script)
    freq_table = compute_freq_table(train_data)
    # Model init
    print("Initializing model...")
    model = DTATTransformer(config)
    model.to(device)
    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        betas=(config.beta1, config.beta2),
        weight_decay=config.weight_decay
    )
    if ddp:
        model = DDP(model, device_ids=[ddp_local_rank])
    # Keep a handle to the unwrapped model for checkpointing
    raw_model = model.module if ddp else model
    # Enable torch.compile if available (PyTorch 2.0+)
    if hasattr(torch, 'compile'):
        try:
            model = torch.compile(model)
            print("Using torch.compile() for faster training")
        except Exception as e:
            print(f"torch.compile() failed ({e}), falling back to default model")
    # Gradient scaler for mixed precision
    scaler = torch.cuda.amp.GradScaler(enabled=config.mixed_precision)
    # Autocast context for the forward pass; a no-op on CPU
    ctx = torch.cuda.amp.autocast(enabled=config.mixed_precision) if device_type == 'cuda' else nullcontext()
    # Enable cuDNN benchmarking for faster training
    torch.backends.cudnn.benchmark = True
    # Create checkpoint directory if it doesn't exist
    checkpoint_dir = os.path.join('checkpoints', 'dtat')
    os.makedirs(checkpoint_dir, exist_ok=True)
    # Training loop
    print("Starting training...")
    print(f"Saving checkpoints to: {checkpoint_dir}")
    # Estimate the number of epochs implied by the iteration budget
    total_steps = config.max_iters
    batch_size = config.batch_size
    block_size = config.block_size
    total_epochs = (total_steps * batch_size * block_size) // len(train_data)
    # Create progress bar
    pbar = tqdm(range(config.max_iters), desc=f"Training (0/{total_epochs} epochs)")
    best_val_loss = float('inf')
    no_improvement = 0
    t0 = time.time()
    for iter_num in pbar:
        # Early stopping check
        if no_improvement >= config.patience:
            print(f"\nEarly stopping triggered after {iter_num} iterations")
            print(f"Best validation loss: {best_val_loss:.4f}")
            break
        # Update learning rate
        lr = get_lr(iter_num, config)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        # Sample a batch of data
        X, Y = get_batch(train_data, config.block_size, config.batch_size, device)
        # Mixed precision forward pass
        with ctx:
            logits, loss, importance_scores = model(X, Y)
        # Backward pass with gradient scaling
        optimizer.zero_grad(set_to_none=True)  # slightly faster than zero_grad()
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        # Clip gradients; clip_grad_norm_ returns the pre-clip total norm
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
        scaler.step(optimizer)
        scaler.update()
        # Logging (master process only)
        if iter_num % config.log_interval == 0 and master_process:
            # Calculate current epoch
            current_tokens = (iter_num + 1) * batch_size * block_size
            current_epoch = current_tokens / len(train_data)
            # Throughput over the iterations since the last timer reset
            interval_iters = config.log_interval if iter_num > 0 else 1
            tokens_per_sec = (batch_size * block_size * interval_iters) / (time.time() - t0)
            importance_mean = importance_scores.mean().item()
            # Update progress bar (the model reports loss directly in BPC)
            pbar.set_description(
                f"Training ({current_epoch:.1f}/{total_epochs} epochs) | "
                f"loss: {loss.item():.4f} | "
                f"bpc: {loss.item():.2f} | "
                f"imp: {importance_mean:.2f} | "
                f"lr: {lr:.1e} | "
                f"tokens/sec: {tokens_per_sec:.1f}"
            )
            # Log to wandb
            wandb.log({
                "iter": iter_num,
                "loss": loss.item(),  # already in BPC
                "bpc": loss.item(),   # same as loss
                "lr": lr,
                "grad_norm": grad_norm.item(),
                "importance_mean": importance_mean,
                "epoch": current_epoch,
                "tokens_per_sec": tokens_per_sec,
            })
            # Reset timer
            t0 = time.time()
            # Visualize importance scores periodically
            if iter_num % (config.log_interval * 10) == 0:
                visualize_importance(
                    X[0].cpu().numpy(),
                    importance_scores[0],
                    iter_num
                )
        # Evaluation
        if iter_num > 0 and iter_num % config.eval_interval == 0:
            val_loss = estimate_loss(model, val_data, config)
            current_epoch = ((iter_num + 1) * batch_size * block_size) / len(train_data)
            # Check for improvement
            if val_loss < best_val_loss - config.min_delta:
                best_val_loss = val_loss
                no_improvement = 0
                if master_process:
                    torch.save(raw_model.state_dict(), os.path.join(checkpoint_dir, 'best.pt'))
                    print(f"Saved best model at iteration {iter_num} with val_loss: {val_loss:.4f}")
            else:
                no_improvement += 1
            # Log validation metrics (master process only)
            if master_process:
                wandb.log({
                    "iter": iter_num,
                    "val_loss": val_loss,
                    "val_bpc": val_loss,
                    "epoch": current_epoch,
                })
        # Save a regular checkpoint every 1000 iterations
        if iter_num % 1000 == 0 and master_process:
            checkpoint = {
                'model_state_dict': raw_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'iter_num': iter_num,
                'best_val_loss': best_val_loss,
                'config': config,
            }
            checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_{iter_num:06d}.pt')
            torch.save(checkpoint, checkpoint_path)
            print(f"\nSaved checkpoint at iteration {iter_num} to {checkpoint_path}")
    if master_process:
        wandb.finish()
    if ddp:
        destroy_process_group()
if __name__ == '__main__':
main()