improve GPU logging to break out pytorch cache and system mem
- scripts/finetune.py +0 -3
- src/axolotl/utils/bench.py +20 -3
- src/axolotl/utils/callbacks.py +3 -3
- src/axolotl/utils/config.py +4 -0
- src/axolotl/utils/models.py +3 -3
- src/axolotl/utils/trainer.py +2 -2
scripts/finetune.py CHANGED
@@ -18,7 +18,6 @@ from optimum.bettertransformer import BetterTransformer
 from transformers import GenerationConfig, TextStreamer
 
 from axolotl.logging_config import configure_logging
-from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
 from axolotl.utils.dict import DictDefault
@@ -226,8 +225,6 @@ def train(
         LOG.info("Finished preparing dataset. Exiting...")
         return
 
-    log_gpu_memory_usage(LOG, "baseline", cfg.device)
-
     # Load the model and tokenizer
     LOG.info("loading model and (optionally) peft_config...")
     model, peft_config = load_model(cfg, tokenizer)
src/axolotl/utils/bench.py CHANGED
@@ -4,13 +4,23 @@ import pynvml
 import torch
 
 
-def gpu_memory_usage(device):
+def gpu_memory_usage(device=0):
+    return torch.cuda.memory_allocated(device) / 1024.0**3
+
+
+def gpu_memory_usage_all(device=0):
+    usage = torch.cuda.memory_allocated(device) / 1024.0**3
+    reserved = torch.cuda.memory_reserved(device) / 1024.0**3
+    smi = gpu_memory_usage_smi(device)
+    return usage, reserved - usage, max(0, smi - reserved)
+
+
+def gpu_memory_usage_smi(device=0):
     if isinstance(device, torch.device):
         device = device.index
     if isinstance(device, str) and device.startswith("cuda:"):
         device = int(device[5:])
 
-    # NB torch.cuda.memory_usage returns zero so we use lower level api
     pynvml.nvmlInit()
     handle = pynvml.nvmlDeviceGetHandleByIndex(device)
     info = pynvml.nvmlDeviceGetMemoryInfo(handle)
@@ -18,6 +28,13 @@ def gpu_memory_usage(device):
 
 
 def log_gpu_memory_usage(log, msg, device):
+    usage, cache, misc = gpu_memory_usage_all(device)
+    extras = []
+    if cache > 0:
+        extras.append(f"+{cache:.03f}GB cache")
+    if misc > 0:
+        extras.append(f"+{misc:.03f}GB misc")
     log.info(
-        f"GPU memory usage {msg}: {
+        f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2
     )
+    return usage, cache, misc
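The new helpers split what used to be a single pynvml figure into three parts: tensors PyTorch has actually allocated, memory the caching allocator has reserved but not handed out, and whatever else the driver reports. Below is a minimal usage sketch, not part of the commit; it assumes a CUDA device is available and uses "cuda:0" and an arbitrary log message for illustration:

import logging

from axolotl.utils.bench import gpu_memory_usage_all, log_gpu_memory_usage

logging.basicConfig(level=logging.INFO)
LOG = logging.getLogger("axolotl")

# usage = memory held by live PyTorch tensors
# cache = memory reserved by the PyTorch caching allocator but not allocated
# misc  = everything else reported by the driver (CUDA context, other processes)
usage, cache, misc = gpu_memory_usage_all("cuda:0")
print(f"allocated={usage:.03f}GB cache={cache:.03f}GB misc={misc:.03f}GB")

# Or in a single call, the way the rest of the codebase logs it:
log_gpu_memory_usage(LOG, "after model load", "cuda:0")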
src/axolotl/utils/callbacks.py CHANGED
@@ -74,10 +74,10 @@ class SaveBetterTransformerModelCallback(
         return control
 
 
-class PrintGPUStatsCallback(
+class GPUStatsCallback(
     TrainerCallback
 ):  # pylint: disable=too-few-public-methods disable=unused-argument
-    """Callback to
+    """Callback to track GPU utilization"""
 
     def __init__(self, cfg):
         self.cfg = cfg
@@ -90,7 +90,7 @@ class PrintGPUStatsCallback(
         control: TrainerControl,
         **kwargs,
     ):
-        if not self.logged:
+        if not self.logged and state.global_step > 1:
            log_gpu_memory_usage(LOG, "while training", self.cfg.device)
            self.logged = True
        return control
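The renamed callback now waits until state.global_step > 1 before sampling, so optimizer state and activation buffers are already resident when the measurement is taken. The hunk does not show which Trainer hook holds this body; the sketch below is a hedged, standalone rendering of the same pattern that assumes on_step_end, not the exact class from the repo:

import logging

from transformers import (
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)

from axolotl.utils.bench import log_gpu_memory_usage

LOG = logging.getLogger("axolotl")


class GPUStatsCallbackSketch(TrainerCallback):
    """Log GPU memory once, shortly after training has actually started."""

    def __init__(self, cfg):
        self.cfg = cfg
        self.logged = False

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        # Skip the first step so lazily created buffers (optimizer state,
        # activation workspaces) exist before we read allocator statistics.
        if not self.logged and state.global_step > 1:
            log_gpu_memory_usage(LOG, "while training", self.cfg.device)
            self.logged = True
        return control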
src/axolotl/utils/config.py CHANGED
@@ -5,6 +5,8 @@ import os
 
 import torch
 
+from axolotl.utils.bench import log_gpu_memory_usage
+
 LOG = logging.getLogger("axolotl")
 
 
@@ -54,6 +56,8 @@ def normalize_config(cfg):
     else:
         torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
 
+    log_gpu_memory_usage(LOG, "baseline", cfg.device)
+
 
 def validate_config(cfg):
     if cfg.max_packed_sequence_len and cfg.sample_packing:
src/axolotl/utils/models.py CHANGED
@@ -381,9 +381,6 @@ def load_model(
             module.scales = module.scales.half()
             module.bias = module.bias.half()
 
-    if model.device.type == "cuda":
-        log_gpu_memory_usage(LOG, "after adapters", model.device)
-
     if (
         torch.cuda.device_count() > 1
         and int(os.getenv("WORLD_SIZE", "1")) > 1
@@ -406,6 +403,9 @@ def load_model(
     if cfg.flash_optimum:
         model = BetterTransformer.transform(model)
 
+    if cfg.adapter is not None:
+        log_gpu_memory_usage(LOG, "after adapters", model.device)
+
     # TODO resume_from_checkpoint handling
     return model, lora_config
 
src/axolotl/utils/trainer.py CHANGED
@@ -22,7 +22,7 @@ from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
 from transformers.trainer_pt_utils import get_parameter_names
 
 from axolotl.utils.callbacks import (
-    PrintGPUStatsCallback,
+    GPUStatsCallback,
     SaveBetterTransformerModelCallback,
     SavePeftModelCallback,
 )
@@ -555,7 +555,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
     trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
 
     callbacks = []
-    callbacks.append(PrintGPUStatsCallback(cfg))
+    callbacks.append(GPUStatsCallback(cfg))
     # TODO on_save callback to sync checkpoints to GCP/AWS in background
     if cfg.early_stopping_patience:
         early_stop_cb = EarlyStoppingCallback(
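For reference, registering the renamed callback mirrors what setup_trainer does above. A minimal sketch with a hypothetical one-key cfg (the real config carries many more fields); the resulting list is what eventually gets passed to transformers.Trainer via its callbacks argument:

from axolotl.utils.callbacks import GPUStatsCallback
from axolotl.utils.dict import DictDefault

cfg = DictDefault({"device": "cuda:0"})  # hypothetical minimal config
callbacks = [GPUStatsCallback(cfg)]
# later: Trainer(..., callbacks=callbacks)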
|