Fix(config): Update handling of deepspeed config (#404)
* Fix(config): Update handling of deepspeed config
* feat: auto set deepspeed env if deepspeed passed
* fix: update new deepspeed instructions
- README.md +5 -2
- src/axolotl/utils/config.py +1 -1
- src/axolotl/utils/trainer.py +7 -12
README.md

````diff
@@ -519,7 +519,7 @@ tokens:
 fsdp:
 fsdp_config:
 
-# Deepspeed
+# Deepspeed config path
 deepspeed:
 
 # Path to torch distx for optim 'adamw_anyprecision'
@@ -570,7 +570,10 @@ fsdp_config:
   fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
 ```
 
-- llama Deepspeed
+- llama Deepspeed
+```yaml
+deepspeed: # path to config
+```
 
 ##### Weights & Biases Logging
 
````
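The README now documents `deepspeed:` as a path to a DeepSpeed JSON config rather than a bare flag. As a quick illustration, here is a minimal sketch (the helper and file names are hypothetical, not part of axolotl) that loads an axolotl YAML config and checks that its `deepspeed` value resolves to a readable JSON file:

```python
# Hypothetical helper, not part of axolotl: sanity-check that the
# `deepspeed:` key in an axolotl YAML config points at a real JSON file.
import json
from pathlib import Path

import yaml  # pyyaml


def check_deepspeed_path(config_path: str) -> None:
    cfg = yaml.safe_load(Path(config_path).read_text())
    ds_path = cfg.get("deepspeed")
    if not ds_path:
        print("no deepspeed path set; the trainer will not enable DeepSpeed")
        return
    # Per the README change, the value is a *path* to a DeepSpeed JSON
    # config (e.g. ./ds_config.json), not an inline dict.
    ds_cfg = json.loads(Path(ds_path).read_text())
    print(f"deepspeed config ok, top-level keys: {sorted(ds_cfg)}")


# Example usage (assumed file name):
# check_deepspeed_path("examples/my_config.yml")
```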
src/axolotl/utils/config.py

```diff
@@ -147,7 +147,7 @@ def validate_config(cfg):
             "You should probably set bfloat16 or float16 to true to "
             "load the model in float16 for BetterTransformers"
         )
-        if int(torch.__version__.split(".")[0]) < 2:
+        if int(torch.__version__.split(".", maxsplit=1)[0]) < 2:
             LOG.warning("torch>=2.0.0 required")
             raise ValueError(
                 f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
```
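The `maxsplit=1` argument is a lint-style cleanup with identical behavior here: only the major version component is consumed, so splitting once is enough. A standalone check (the version strings are examples):

```python
# Splitting once yields the same major version as a full split,
# including for local-version suffixes like "+cu118".
for version in ("1.13.1", "2.0.1+cu118"):
    assert version.split(".", maxsplit=1)[0] == version.split(".")[0]
    print(version, "-> major", int(version.split(".", maxsplit=1)[0]))
```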
src/axolotl/utils/trainer.py

```diff
@@ -364,6 +364,9 @@ def setup_fsdp_envs(cfg):
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
     if cfg.fsdp:
         setup_fsdp_envs(cfg)
+    elif cfg.deepspeed:
+        os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
+
     warmup_steps = (
         cfg.warmup_steps
         if cfg.warmup_steps is not None
@@ -411,21 +414,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
     if cfg.fsdp_config:
         training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)
 
+    # deepspeed
+    if cfg.deepspeed:
+        training_arguments_kwargs["deepspeed"] = cfg.deepspeed
+
     if cfg.lr_quadratic_warmup is not None:
         training_arguments_kwargs["lr_quadratic_warmup"] = cfg.lr_quadratic_warmup
 
-    # deepspeed
-    if (
-        os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true"
-        and torch.cuda.device_count() > 1
-    ):
-        if cfg.deepspeed:
-            training_arguments_kwargs["deepspeed"] = cfg.deepspeed
-        else:
-            # make a guess here
-            # TODO search Path("./") for one
-            training_arguments_kwargs["deepspeed"] = "./ds_config.json"
-
     if cfg.adam_beta1:
         training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1
     if cfg.adam_beta2:
```
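Taken together, the trainer changes trade implicit detection for explicit config: when `cfg.deepspeed` is set, `ACCELERATE_USE_DEEPSPEED` is exported automatically and the path is always forwarded to the trainer kwargs, instead of the old behavior of guessing `./ds_config.json` whenever the launcher had set the env var and more than one GPU was visible. A minimal, self-contained sketch of the new flow (`cfg` here is a hypothetical stand-in for an axolotl config object):

```python
# Sketch of the new control flow; `cfg` is a hypothetical stand-in for
# an axolotl config object, not axolotl's own class.
import os
from types import SimpleNamespace

cfg = SimpleNamespace(fsdp=None, deepspeed="./ds_config.json")

training_arguments_kwargs = {}

if cfg.fsdp:
    pass  # axolotl calls setup_fsdp_envs(cfg) here
elif cfg.deepspeed:
    # New: the env var is exported whenever a deepspeed config path is
    # given, instead of relying on the launcher to have set it.
    os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"

if cfg.deepspeed:
    # New: the path is forwarded unconditionally; the old code only did
    # this when ACCELERATE_USE_DEEPSPEED was already "true" and more
    # than one GPU was visible, guessing "./ds_config.json" otherwise.
    training_arguments_kwargs["deepspeed"] = cfg.deepspeed

# These kwargs ultimately feed transformers' TrainingArguments, whose
# `deepspeed` parameter accepts a path to a DeepSpeed JSON config.
print(training_arguments_kwargs)
```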