fsdp config dict fix, todo list, add torchdistx support

- TODO.md +10 -0
- src/axolotl/utils/models.py +5 -0
- src/axolotl/utils/trainer.py +9 -3
TODO.md
ADDED

@@ -0,0 +1,10 @@
+# todo list
+
+- [ ] Validation of parameters for combinations that won't work
+
+
+
+## things that are known not to work
+
+- FSDP offload and gradient_checkpointing - https://github.com/pytorch/pytorch/issues/82203
+- adamw_bnb_8bit doesn't play well with FSDP offload
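As a rough illustration of the first TODO item, here is a minimal sketch of what a parameter-combination check could look like, assuming a cfg object exposing the fsdp, gradient_checkpointing, and optimizer fields used in the diffs below. The function name and the way FSDP offload is detected are hypothetical and not part of this commit.

```python
def validate_config(cfg):
    """Hypothetical pre-flight check for the known-bad combinations listed above."""
    # Assumes FSDP CPU offload is requested via an "offload" option in cfg.fsdp,
    # mirroring the option strings accepted by TrainingArguments' `fsdp` field.
    fsdp_offload = bool(cfg.fsdp) and "offload" in str(cfg.fsdp)

    # FSDP offload + gradient checkpointing fails upstream:
    # https://github.com/pytorch/pytorch/issues/82203
    if fsdp_offload and cfg.gradient_checkpointing:
        raise ValueError("FSDP offload and gradient_checkpointing cannot be combined")

    # per the note above, adamw_bnb_8bit doesn't play well with FSDP offload
    if fsdp_offload and cfg.optimizer == "adamw_bnb_8bit":
        raise ValueError("adamw_bnb_8bit does not work with FSDP offload")
```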
src/axolotl/utils/models.py
CHANGED

@@ -179,6 +179,11 @@ def load_model(
        m.scales = m.scales.half()
        m.bias = m.bias.half()
 
+    if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) > 1:
+        model.is_parallelizable = True
+        model.model_parallel = True
+
+
    # TODO resume_from_checkpoint handling
    return model, tokenizer, lora_config
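For context, a minimal sketch of the guard added above, assuming a Hugging Face transformers model: when more than one GPU is visible and the run was launched with more than one process (WORLD_SIZE is set by launchers such as torchrun), the model is flagged with the attributes the Hugging Face Trainer inspects to treat it as already model-parallel. The helper name is illustrative.

```python
import os

import torch


def mark_model_parallel(model):
    """Illustrative helper mirroring the change above.

    Sets the attributes the Hugging Face Trainer checks
    (is_parallelizable / model_parallel) so a multi-GPU, multi-process
    run is not additionally wrapped in DataParallel.
    """
    if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) > 1:
        model.is_parallelizable = True
        model.model_parallel = True
    return model
```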
src/axolotl/utils/trainer.py
CHANGED

@@ -1,5 +1,7 @@
+import importlib
 import math
 import os
+import sys
 from pathlib import Path
 
 import bitsandbytes as bnb
@@ -35,9 +37,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     else:
         training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
     if cfg.fsdp:
-        training_arguments_kwargs["fsdp"] = cfg.fsdp
-        if cfg.fsdp_config:
-            training_arguments_kwargs["fsdp_config"] = cfg.fsdp_config
+        training_arguments_kwargs["fsdp"] = cfg.fsdp
+        if cfg.fsdp_config:
+            training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)
 
 
     # deepspeed
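A short sketch of where these kwargs end up, assuming cfg.fsdp_config is an attribute-style mapping loaded from YAML rather than a plain dict: TrainingArguments accepts fsdp_config as a regular dict (or a path to a JSON file), so the dict() conversion above hands it a plain dict instead of the custom mapping type. The output_dir and the fsdp_config contents below are illustrative values, not taken from this commit.

```python
from transformers import TrainingArguments

# Illustrative FSDP settings (hypothetical values); fsdp_config is passed as a plain dict.
fsdp = "full_shard auto_wrap"
fsdp_config = {"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer"}

training_args = TrainingArguments(
    output_dir="./out",
    fsdp=fsdp,                      # same space-separated option string cfg.fsdp would hold
    fsdp_config=dict(fsdp_config),  # plain dict, mirroring the fix above
)
```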
@@ -73,6 +75,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
 
     trainer_kwargs = {}
 
+    if cfg.optimizer == "adamw_anyprecision":
+        if Path(cfg.torchdistx_path).exists():
+            sys.path.append(cfg.torchdistx_path)
+            torchdistx = importlib.import_module('torchdistx')
     if cfg.optimizer == "adam8bit" and not cfg.load_4bit and not "deepspeed" in training_arguments_kwargs:
         decay_parameters = get_parameter_names(model, [nn.LayerNorm])
         decay_parameters = [name for name in decay_parameters if "bias" not in name]
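Finally, a rough sketch of how the new torchdistx hook is intended to be used, assuming a config that sets optimizer: adamw_anyprecision and torchdistx_path to a local torchdistx build: the path is put on sys.path and the module imported so the Trainer can construct torchdistx's AnyPrecisionAdamW optimizer. The path and literal values below are hypothetical; in axolotl they come from the YAML config (cfg).

```python
import importlib
import sys
from pathlib import Path

from transformers import TrainingArguments

# Hypothetical values; in axolotl these come from the YAML config (cfg).
optimizer = "adamw_anyprecision"
torchdistx_path = "/opt/torchdistx/src/python"  # local torchdistx checkout/build

if optimizer == "adamw_anyprecision" and Path(torchdistx_path).exists():
    # Make the local torchdistx build importable before the Trainer needs it.
    sys.path.append(torchdistx_path)
    importlib.import_module("torchdistx")

training_args = TrainingArguments(
    output_dir="./out",
    optim=optimizer,  # transformers resolves this to torchdistx's AnyPrecisionAdamW
)
```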