use recommended setting for use_reentrant with gradient checkpointing (#1021)

* use recommended setting for use_reentrant with gradient checkpointing
* add doc for gradient_checkpointing_kwargs
- README.md +3 -0
- src/axolotl/core/trainer_builder.py +8 -0
README.md
CHANGED
@@ -741,6 +741,9 @@ group_by_length: false
 # Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
 gradient_checkpointing: false
+# additional kwargs to pass to the trainer for gradient checkpointing
+# gradient_checkpointing_kwargs:
+#   use_reentrant: false
 
 # Stop training after this many evaluation losses have increased in a row
 # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
src/axolotl/core/trainer_builder.py
CHANGED
@@ -566,6 +566,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs[
             "gradient_checkpointing"
         ] = self.cfg.gradient_checkpointing
+        if self.cfg.gradient_checkpointing_kwargs:
+            training_arguments_kwargs[
+                "gradient_checkpointing_kwargs"
+            ] = self.cfg.gradient_checkpointing_kwargs
+        else:
+            training_arguments_kwargs["gradient_checkpointing_kwargs"] = {
+                "use_reentrant": False
+            }
         if self.cfg.fsdp:
             training_arguments_kwargs["fsdp"] = self.cfg.fsdp
         if self.cfg.fsdp_config: