Fix DeepSpeed Zero 3 Saving (#709)
* Update train.py
* add zero3 check
* chore: lint
---------
Co-authored-by: Wing Lian <[email protected]>
- src/axolotl/train.py +17 -0
src/axolotl/train.py
CHANGED
|
@@ -12,6 +12,7 @@ import torch
|
|
| 12 | import transformers.modelcard
| 13 | from datasets import Dataset
| 14 | from optimum.bettertransformer import BetterTransformer
| 15 |
| 16 | from axolotl.common.cli import TrainerCliArgs
| 17 | from axolotl.logging_config import configure_logging
|
@@ -134,6 +135,22 @@ def train(
|
|
| 134 |     # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
| 135 |     if cfg.fsdp:
| 136 |         trainer.save_model(cfg.output_dir)
| 137 |     elif cfg.local_rank == 0:
| 138 |         if cfg.flash_optimum:
| 139 |             model = BetterTransformer.reverse(model)
|
|
|
| 12 | import transformers.modelcard
| 13 | from datasets import Dataset
| 14 | from optimum.bettertransformer import BetterTransformer
| 15 | + from transformers.deepspeed import is_deepspeed_zero3_enabled
| 16 |
| 17 | from axolotl.common.cli import TrainerCliArgs
| 18 | from axolotl.logging_config import configure_logging
|
|
|
| 135 |     # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
| 136 |     if cfg.fsdp:
| 137 |         trainer.save_model(cfg.output_dir)
| 138 | +   elif cfg.deepspeed and is_deepspeed_zero3_enabled():
| 139 | +       # Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
| 140 | +       trainer.accelerator.wait_for_everyone()
| 141 | +       unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped)
| 142 | +
| 143 | +       # Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
| 144 | +       # `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
| 145 | +       # `zero3_save_16bit_model` is True in DeepSpeed Plugin.
| 146 | +       # For Zero Stages 1 and 2, models are saved as usual in the output directory.
| 147 | +       # The model name saved is `pytorch_model.bin`
| 148 | +       unwrapped_model.save_pretrained(
| 149 | +           cfg.output_dir,
| 150 | +           is_main_process=trainer.accelerator.is_main_process,
| 151 | +           save_function=trainer.accelerator.save,
| 152 | +           state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
| 153 | +       )
| 154 |     elif cfg.local_rank == 0:
| 155 |         if cfg.flash_optimum:
| 156 |             model = BetterTransformer.reverse(model)