from vram_helpers import (
    activations_memory,
    activations_memory_per_layer,
    gradients_memory,
    kv_cache_memory,
    model_memory,
    optimizer_memory,
)


def training_vram_required(model_config, training_config):
    # Reference: https://www.microsoft.com/en-us/research/blog/zero-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/
    trainable_parameters = model_config.model_size
    if training_config.qlora:
        # QLoRA: base weights are quantized to 4-bit and frozen; only the LoRA
        # adapters are trained. ~0.02% trainable parameters is assumed here
        # (the LoRA paper reports up to a ~10,000x reduction in trainable
        # parameters, https://arxiv.org/pdf/2106.09685).
        model_config.precision = "int4"
        trainable_parameters = 0.0002 * model_config.model_size

    # Weights are stored for the full model; gradients and optimizer states
    # only exist for the trainable parameters.
    model_vram = model_memory(parameters=model_config.model_size,
                              precision=model_config.precision,
                              mixed_precision=model_config.mixed_precision)
    gradients_vram = gradients_memory(parameters=trainable_parameters)
    optimizer_vram = optimizer_memory(parameters=trainable_parameters,
                                      optimizer=training_config.optimizer)

    # ZeRO stage 0 is the baseline: nothing is partitioned across GPUs.
    # Stage 1: optimizer state partitioning
    if training_config.zero_stage >= 1:
        optimizer_vram = optimizer_vram / training_config.num_gpus
    # Stage 2: gradient + optimizer state partitioning
    if training_config.zero_stage >= 2:
        gradients_vram = gradients_vram / training_config.num_gpus
    # Stage 3: parameter + gradient + optimizer state partitioning
    if training_config.zero_stage == 3:
        model_vram = model_vram / training_config.num_gpus

    aggregated_vram = model_vram + gradients_vram + optimizer_vram

    activations_vram = activations_memory(model_config.num_layers,
                                          model_config.sequence_length,
                                          training_config.micro_batch_size,
                                          model_config.hidden_size,
                                          model_config.num_heads)
    if training_config.gradient_checkpointing:
        # Rough square-root heuristic for activation recomputation.
        activations_vram = round(activations_vram ** 0.5, 2)

    total_vram = aggregated_vram + activations_vram
    return {k: round(v, 2) for k, v in {
        "total": total_vram,
        "model": model_vram,
        "gradients": gradients_vram,
        "optimizer": optimizer_vram,
        "activations": activations_vram
    }.items()}


def inference_vram_required(model_config, training_config):
    # Mixed precision only matters for training; weights are served in a
    # single precision at inference time.
    model_config.mixed_precision = False

    # Total inference VRAM ~= model weights + KV cache + activations
    # (runtime/framework overhead is not modeled here).
    model_vram = model_memory(parameters=model_config.model_size,
                              precision=model_config.precision,
                              mixed_precision=model_config.mixed_precision)
    kv_cache_vram = kv_cache_memory(batch_size=training_config.micro_batch_size,
                                    total_sequence_length=model_config.total_sequence_length,
                                    num_layers=model_config.num_layers,
                                    num_heads=model_config.num_heads,
                                    hidden_size=model_config.hidden_size,
                                    precision=model_config.precision)
    # Activations are estimated per layer, since intermediate results are not
    # kept across layers at inference.
    activations_vram = activations_memory_per_layer(sequence_length=model_config.sequence_length,
                                                    micro_batch_size=training_config.micro_batch_size,
                                                    hidden_size=model_config.hidden_size,
                                                    num_heads=model_config.num_heads)

    total_vram = model_vram + kv_cache_vram + activations_vram
    return {k: round(v, 2) for k, v in {
        "total": total_vram,
        "model": model_vram,
        "kv_cache": kv_cache_vram,
        "activations": activations_vram
    }.items()}
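

# --- Usage sketch (illustrative only) ---
# A minimal example of how these functions might be called, assuming the config
# objects simply expose the attributes read above. The example values (a 7B
# model in bf16 on 4 GPUs), the "bf16" precision string, and the "adamw"
# optimizer string are hypothetical and depend on what vram_helpers accepts.
if __name__ == "__main__":
    from types import SimpleNamespace

    # Hypothetical model description; model_size is the parameter count.
    model_config = SimpleNamespace(
        model_size=7e9,
        precision="bf16",
        mixed_precision=True,
        num_layers=32,
        hidden_size=4096,
        num_heads=32,
        sequence_length=2048,
        total_sequence_length=4096,  # prompt + generated tokens, for the KV cache
    )
    # Hypothetical training setup.
    training_config = SimpleNamespace(
        qlora=False,
        optimizer="adamw",
        zero_stage=2,
        num_gpus=4,
        micro_batch_size=1,
        gradient_checkpointing=True,
    )

    print("training:", training_vram_required(model_config, training_config))
    print("inference:", inference_vram_required(model_config, training_config))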