Unsloth gradient checkpointing offload (#1528)
* unsloth gradient checkpointing
* fix validation too
* fixes to make it work with mistral
* monkeypatch the checkpoint fn earlier

src/axolotl/monkeypatch/mistral_attn_hijack_flash.py (changed)

@@ -516,24 +516,18 @@ def mistral_model_forward(
         past_key_value = past_key_values[idx] if past_key_values is not None else None

         if self.gradient_checkpointing and self.training:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    # None for past_key_value
-                    return module(*inputs)
-
-                return custom_forward
-
-            layer_outputs = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(decoder_layer),
-                hidden_states,
-                attention_mask,
-                position_ids,
-                past_key_value,
-                output_attentions,
-                None,
-                cu_seqlens,
-                max_seqlen,
+            layer_outputs = (
+                self._gradient_checkpointing_func(  # pylint: disable=protected-access
+                    decoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                    past_key_value,
+                    output_attentions,
+                    None,
+                    cu_seqlens,
+                    max_seqlen,
+                )
             )
         else:
             layer_outputs = decoder_layer(
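
The patched branch now routes through `self._gradient_checkpointing_func`, the hook transformers installs on the model when gradient checkpointing is enabled, rather than calling `torch.utils.checkpoint.checkpoint` directly. A minimal sketch of why that indirection matters, using hypothetical `ToyLayer`/`ToyModel` classes rather than axolotl code: whatever callable is bound to that attribute, including the unsloth wrapper introduced below, is what every layer's checkpointing goes through.

```python
# Sketch only: a model that defers to a swappable checkpoint function.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, hidden_states):
        return self.linear(hidden_states)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(ToyLayer() for _ in range(2))
        # Whatever is bound here decides *how* checkpointing happens;
        # rebinding it swaps implementations without touching forward().
        self._gradient_checkpointing_func = checkpoint

    def forward(self, hidden_states):
        for layer in self.layers:
            hidden_states = self._gradient_checkpointing_func(
                layer.__call__, hidden_states, use_reentrant=False
            )
        return hidden_states


model = ToyModel()
out = model(torch.randn(2, 8, requires_grad=True))
out.sum().backward()  # activations recomputed in backward, gradients still flow
```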
    	
src/axolotl/utils/config/models/input/v0_4_1/__init__.py (changed)

@@ -479,6 +479,7 @@ class AxolotlInputConfig(
     eval_causal_lm_metrics: Optional[List[str]] = None
     do_bench_eval: Optional[bool] = None
     bench_dataset: Optional[str] = None
+    bench_split: Optional[str] = None
     metric_for_best_model: Optional[str] = None
     greater_is_better: Optional[bool] = None

@@ -494,7 +495,9 @@ class AxolotlInputConfig(

     # torch_dtype: Optional[torch.dtype]

-    gradient_checkpointing: Optional[bool] = Field(default=False)
+    gradient_checkpointing: Optional[Union[Literal["unsloth"], bool]] = Field(
+        default=False
+    )
     gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None

     unfrozen_parameters: Optional[List[str]] = None
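
Widening the annotation is what lets a config request the new mode; presumably an axolotl YAML config can now set `gradient_checkpointing: unsloth`, while plain booleans keep validating as before. A small sketch of the resulting behavior with a standalone pydantic model (a stand-in for illustration, not the real `AxolotlInputConfig`):

```python
from typing import Literal, Optional, Union

from pydantic import BaseModel, Field


class Cfg(BaseModel):
    """Standalone stand-in for the widened field, illustration only."""

    gradient_checkpointing: Optional[Union[Literal["unsloth"], bool]] = Field(
        default=False
    )


assert Cfg().gradient_checkpointing is False  # unchanged default
assert Cfg(gradient_checkpointing=True).gradient_checkpointing is True
assert Cfg(gradient_checkpointing="unsloth").gradient_checkpointing == "unsloth"
```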
    	
src/axolotl/utils/gradient_checkpointing/__init__.py (added)

"""custom checkpointing utils"""
from axolotl.utils.gradient_checkpointing.unsloth import (
    Unsloth_Offloaded_Gradient_Checkpointer,
)


def hf_grad_checkpoint_unsloth_wrapper(
    decoder_layer, *args, use_reentrant=None
):  # pylint: disable=unused-argument
    return Unsloth_Offloaded_Gradient_Checkpointer.apply(
        decoder_layer.__self__,
        *args,
    )
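
The wrapper receives whatever transformers hands its checkpoint function, which after the mistral patch above is the bound method `decoder_layer.__call__`; `.__self__` recovers the module instance so the autograd Function in `unsloth.py` can replay it. A tiny sketch of that round-trip (the `nn.Linear` stand-in is illustrative only):

```python
import torch.nn as nn

layer = nn.Linear(4, 4)
bound = layer.__call__           # what the model hands to the checkpoint fn
assert bound.__self__ is layer   # the wrapper unwraps back to the module itself
```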
    	
src/axolotl/utils/gradient_checkpointing/unsloth.py (added)

"""Unsloth checkpointing"""

# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch


class Unsloth_Offloaded_Gradient_Checkpointer(  # pylint: disable=invalid-name
    torch.autograd.Function
):
    """
    Saves VRAM by smartly offloading to RAM.
    Tiny hit to performance, since we mask the movement via non blocking calls.
    """

    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, forward_function, hidden_states, *args):
        saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
        with torch.no_grad():
            output = forward_function(hidden_states, *args)
        ctx.save_for_backward(saved_hidden_states)
        ctx.forward_function = forward_function
        ctx.args = args
        return output

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, dY):
        (hidden_states,) = ctx.saved_tensors
        hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
        hidden_states.requires_grad = True
        with torch.enable_grad():
            (output,) = ctx.forward_function(hidden_states, *ctx.args)
        torch.autograd.backward(output, dY)
        return (
            None,
            hidden_states.grad,
        ) + (
            None,
        ) * len(ctx.args)
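
As a usage sketch only (hypothetical `ToyBlock`, and a CUDA device is required since the backward pass moves the saved input back to `"cuda"`): the callable handed to `apply` must return a 1-tuple, mirroring how a decoder layer returns a tuple of outputs.

```python
import torch
import torch.nn as nn

from axolotl.utils.gradient_checkpointing.unsloth import (
    Unsloth_Offloaded_Gradient_Checkpointer,
)


class ToyBlock(nn.Module):
    """Illustrative stand-in for a decoder layer: returns a 1-tuple."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)

    def forward(self, hidden_states):
        return (self.linear(hidden_states),)


block = ToyBlock().cuda()
x = torch.randn(2, 8, 16, device="cuda", requires_grad=True)

# The Function's forward runs the block under no_grad and stashes a CPU copy of
# the input; its backward moves that copy back to the GPU, replays the block
# with grad enabled, and returns the input gradient (None for the module and
# any extra args).
(out,) = Unsloth_Offloaded_Gradient_Checkpointer.apply(block, x)
out.sum().backward()
assert x.grad is not None and block.linear.weight.grad is not None
```

The trade-off is a recompute of the block during backward plus a CPU round-trip for the input activations, in exchange for not keeping those activations resident in VRAM between forward and backward.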
    	
src/axolotl/utils/models.py (changed)

@@ -11,6 +11,7 @@ import addict
 import bitsandbytes as bnb
 import torch
 import transformers
+import transformers.modeling_utils
 from accelerate import init_empty_weights
 from bitsandbytes.nn import Params4bit
 from peft import (
@@ -44,6 +45,7 @@ from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.chat_templates import chat_templates
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import zero_only
+from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
 from axolotl.utils.lora_embeddings import get_linear_embedding_layers
 from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant

@@ -310,6 +312,9 @@ def load_model(
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit

+    if cfg.gradient_checkpointing == "unsloth":
+        transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
+
     if hasattr(model_config, "model_type") and model_config.model_type == "btlm":
         if cfg.flash_attention:
             from axolotl.monkeypatch.btlm_attn_hijack_flash import (
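
The last commit's point ("monkeypatch the checkpoint fn earlier") is ordering: at the time of this PR, `transformers.modeling_utils` imports `checkpoint` from `torch.utils.checkpoint`, and `gradient_checkpointing_enable()` builds each model's `_gradient_checkpointing_func` from that module-level name, so rebinding it before the model is loaded should route every checkpointed layer through the unsloth wrapper, whose ignored `use_reentrant` parameter absorbs the kwarg transformers passes along. A hedged sketch of the sequencing, assuming that resolution behavior:

```python
import functools

import transformers.modeling_utils

from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper

# Rebind the module-level name before any model enables checkpointing.
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper

# Roughly what gradient_checkpointing_enable() builds afterwards: a partial over
# that name, which now resolves to the unsloth wrapper (the wrapper's ignored
# use_reentrant parameter swallows the kwarg).
ckpt_func = functools.partial(
    transformers.modeling_utils.checkpoint, use_reentrant=True
)
assert ckpt_func.func is hf_grad_checkpoint_unsloth_wrapper
```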
