fix FSDP save of final model (#329)
Browse files- scripts/finetune.py +3 -1
scripts/finetune.py
CHANGED
|
@@ -344,7 +344,9 @@ def train(
|
|
| 344 |
|
| 345 |
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
| 346 |
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
| 347 |
-
if cfg.
|
|
|
|
|
|
|
| 348 |
if cfg.flash_optimum:
|
| 349 |
model = BetterTransformer.reverse(model)
|
| 350 |
model.save_pretrained(cfg.output_dir)
|
|
|
|
| 344 |
|
| 345 |
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
| 346 |
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
| 347 |
+
if cfg.fsdp:
|
| 348 |
+
model.save_pretrained(cfg.output_dir)
|
| 349 |
+
elif cfg.local_rank == 0:
|
| 350 |
if cfg.flash_optimum:
|
| 351 |
model = BetterTransformer.reverse(model)
|
| 352 |
model.save_pretrained(cfg.output_dir)
|