import os
import argparse

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, DistributedSampler

from model.gpt_model import GPTModel
from data.dataset import TextDataset
from data import utils

try:
    import deepspeed
except ImportError:
    deepspeed = None

try:
    from torch.nn.parallel import DistributedDataParallel as DDP
except ImportError:
    DDP = None


def main():
    parser = argparse.ArgumentParser(description="Train the OpenGPT model.")
    parser.add_argument("--config", type=str, required=True,
                        help="Path to configuration file (YAML/JSON).")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="Local rank for distributed training.")
    args = parser.parse_args()

    # Load configuration
    config = utils.load_config(args.config)
    model_conf = config.get("model", {})
    train_conf = config.get("training", {})
    data_conf = config.get("data", {})

    # Distributed setup
    local_rank = args.local_rank
    if local_rank == -1:
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
    distributed = False
    if "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1:
        distributed = True
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = torch.device("cuda", local_rank) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)

    # Set random seed for reproducibility
    utils.set_seed(train_conf.get("seed", 42))

    # Prepare dataset and dataloader
    train_dataset = TextDataset(data_conf["train_path"],
                                data_conf["tokenizer_path"],
                                data_conf.get("block_size", 128))
    train_sampler = DistributedSampler(train_dataset) if distributed else None
    train_loader = DataLoader(train_dataset,
                              batch_size=train_conf.get("batch_size", 1),
                              sampler=train_sampler,
                              shuffle=(train_sampler is None))

    # Initialize model
    model = GPTModel(vocab_size=model_conf["vocab_size"],
                     max_position_embeddings=model_conf.get("max_position_embeddings", 512),
                     n_layers=model_conf.get("n_layers", 12),
                     n_heads=model_conf.get("n_heads", 12),
                     hidden_dim=model_conf.get("embedding_dim", 768),
                     dropout=model_conf.get("dropout", 0.1)).to(device)

    # Optionally load a pre-trained checkpoint to fine-tune
    init_checkpoint = train_conf.get("init_checkpoint", "")
    if init_checkpoint:
        utils.load_checkpoint(model, optimizer=None, filepath=init_checkpoint, device=device)

    # Create optimizer (AdamW by default)
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=train_conf.get("learning_rate", 5e-4),
                                  weight_decay=train_conf.get("weight_decay", 0.0))

    # Mixed precision training
    mixed_precision = train_conf.get("mixed_precision", False) and torch.cuda.is_available()
    scaler = torch.cuda.amp.GradScaler() if mixed_precision else None

    # Initialize DeepSpeed if enabled
    use_deepspeed = False
    ds_config_path = train_conf.get("deepspeed_config", None)
    if ds_config_path and deepspeed is not None:
        use_deepspeed = True
        model, optimizer, _, _ = deepspeed.initialize(model=model,
                                                      optimizer=optimizer,
                                                      config=ds_config_path)

    # If using DDP (and not DeepSpeed), wrap the model
    if distributed and not use_deepspeed and DDP is not None:
        model = DDP(model, device_ids=[local_rank])

    # Training loop
    epochs = train_conf.get("epochs", 1)
    for epoch in range(epochs):
        if distributed and train_sampler is not None:
            train_sampler.set_epoch(epoch)
        model.train()
        total_loss = 0.0
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)

            if mixed_precision:
                with torch.cuda.amp.autocast():
                    outputs = model(inputs)
                    loss = F.cross_entropy(outputs.view(-1, model_conf["vocab_size"]),
                                           targets.view(-1))
                if use_deepspeed:
                    # DeepSpeed handles loss scaling and the optimizer step internally
                    model.backward(loss)
                    model.step()
                else:
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()
            else:
                outputs = model(inputs)
                loss = F.cross_entropy(outputs.view(-1, model_conf["vocab_size"]),
                                       targets.view(-1))
                if use_deepspeed:
                    model.backward(loss)
                    model.step()
                else:
                    loss.backward()
                    optimizer.step()
                    optimizer.zero_grad()

            total_loss += loss.item()

            # Print progress occasionally (only on rank 0 if distributed)
            if batch_idx % 100 == 0 and (not distributed or torch.distributed.get_rank() == 0):
                avg_loss = total_loss / (batch_idx + 1)
                print(f"Epoch {epoch+1}, Batch {batch_idx}, "
                      f"Loss: {loss.item():.4f}, Avg Loss: {avg_loss:.4f}")

        # Save checkpoint at epoch end
        ckpt_dir = train_conf.get("checkpoint_dir", "checkpoints")
        if use_deepspeed:
            # DeepSpeed checkpointing is a collective call: every rank must enter it
            os.makedirs(ckpt_dir, exist_ok=True)
            model.save_checkpoint(ckpt_dir, tag=f"epoch-{epoch+1}")
        elif not distributed or torch.distributed.get_rank() == 0:
            os.makedirs(ckpt_dir, exist_ok=True)
            # Unwrap the DDP container so the checkpoint stores plain module weights
            model_to_save = model.module if hasattr(model, "module") else model
            ckpt_path = os.path.join(ckpt_dir, f"epoch{epoch+1}.pt")
            utils.save_checkpoint(model_to_save, optimizer, ckpt_path)
            print(f"Checkpoint saved: {ckpt_path}")


if __name__ == "__main__":
    main()
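
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only: the script and file names below are
# placeholders, and the config values are hypothetical; only the config keys
# mirror the lookups actually made in main() above).
#
#   Single GPU / CPU:
#       python train.py --config configs/train.yaml
#
#   Multi-GPU (torchrun exports WORLD_SIZE and LOCAL_RANK, which main() reads
#   to decide whether to initialize the NCCL process group):
#       torchrun --nproc_per_node=4 train.py --config configs/train.yaml
#
# A minimal YAML config matching the keys read above could look like:
#
#   model:
#     vocab_size: 50257
#     max_position_embeddings: 512
#     n_layers: 12
#     n_heads: 12
#     embedding_dim: 768
#     dropout: 0.1
#   training:
#     batch_size: 8
#     learning_rate: 5.0e-4
#     epochs: 3
#     mixed_precision: true
#     checkpoint_dir: checkpoints
#   data:
#     train_path: data/train.txt
#     tokenizer_path: tokenizer.json
#     block_size: 128
# ---------------------------------------------------------------------------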