use recommended setting for use_reentrant w gradient checkpointing (#1021)
* use recommended setting for use_reentrant with gradient checkpointing
* add doc for gradient_checkpointing_kwargs
- README.md +3 -0
- src/axolotl/core/trainer_builder.py +8 -0
    	
README.md CHANGED

@@ -741,6 +741,9 @@ group_by_length: false
 
 # Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
 gradient_checkpointing: false
+# additional kwargs to pass to the trainer for gradient checkpointing
+# gradient_checkpointing_kwargs:
+#   use_reentrant: false
 
 # Stop training after this many evaluation losses have increased in a row
 # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
    	
src/axolotl/core/trainer_builder.py CHANGED

@@ -566,6 +566,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs[
                 "gradient_checkpointing"
             ] = self.cfg.gradient_checkpointing
+            if self.cfg.gradient_checkpointing_kwargs:
+                training_arguments_kwargs[
+                    "gradient_checkpointing_kwargs"
+                ] = self.cfg.gradient_checkpointing_kwargs
+            else:
+                training_arguments_kwargs["gradient_checkpointing_kwargs"] = {
+                    "use_reentrant": False
+                }
         if self.cfg.fsdp:
             training_arguments_kwargs["fsdp"] = self.cfg.fsdp
         if self.cfg.fsdp_config:
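The kwargs dict built above is passed through to transformers' TrainingArguments, which accepts gradient_checkpointing_kwargs in recent versions (4.35+) and forwards it to model.gradient_checkpointing_enable(). A minimal sketch of the arguments this change produces by default, with a placeholder output_dir:

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./out",  # placeholder path
    gradient_checkpointing=True,
    # the default this commit applies when the axolotl config sets no kwargs
    gradient_checkpointing_kwargs={"use_reentrant": False},
)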