allow the optimizer prune ratio for ReLoRA to be configurable (#1287)
* allow the optimizer prune ratio for relora to be configurable
* update docs for relora
* prevent circular imports
- README.md +2 -0
- src/axolotl/core/trainer_builder.py +18 -3
- src/axolotl/monkeypatch/relora.py +3 -1
    	
README.md

```diff
@@ -734,6 +734,8 @@ peft:
 # Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
 relora_steps: # Number of steps per ReLoRA restart
 relora_warmup_steps: # Number of per-restart warmup steps
+relora_anneal_steps: # Number of anneal steps for each relora cycle
+relora_prune_ratio: # threshold for optimizer magnitude when pruning
 relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
 
 # wandb configuration if you're using it
```
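Read together, this block of the README now documents a complete set of ReLoRA knobs. Below is a minimal example of how they might be combined in an axolotl YAML config; the values are illustrative, not recommendations, and the comments paraphrase the docs above:

```yaml
adapter: lora              # ReLoRA needs the 'lora' or 'qlora' adapter
relora_steps: 150          # restart every 150 optimizer steps
relora_warmup_steps: 10    # per-restart LR warmup
relora_anneal_steps: 10    # per-cycle LR anneal
relora_prune_ratio: 0.9    # same value the code defaults to when unset
relora_cpu_offload: false
```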
    	
src/axolotl/core/trainer_builder.py

```diff
@@ -131,6 +131,10 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=None,
         metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
     )
+    relora_prune_ratio: Optional[float] = field(
+        default=0.9,
+        metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
+    )
     bench_split: Optional[str] = field(
         default="eval", metadata={"help": "The benchmark split to run on"}
     )
@@ -900,9 +904,20 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs[
                 "sample_packing_seq_len_multiplier"
             ] = self.cfg.micro_batch_size
-
-
-
+            if self.cfg.relora_steps:
+                training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
+                training_arguments_kwargs[
+                    "relora_warmup_steps"
+                ] = self.cfg.relora_warmup_steps
+                if self.cfg.relora_anneal_steps:
+                    training_arguments_kwargs[
+                        "relora_anneal_steps"
+                    ] = self.cfg.relora_anneal_steps
+                if self.cfg.relora_prune_ratio:
+                    training_arguments_kwargs[
+                        "relora_prune_ratio"
+                    ] = self.cfg.relora_prune_ratio
+
             training_arguments_kwargs = self.hook_pre_create_training_args(
                 training_arguments_kwargs
             )
```
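The builder change follows the file's existing pattern: optional config values are copied into `training_arguments_kwargs` only when set, and the dict is later splatted into `AxolotlTrainingArguments`. A self-contained sketch of that pattern under stated assumptions (the stand-in dataclass and `cfg` dict are hypothetical; the `relora_*` names and the 0.9 default match the diff):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class SketchTrainingArgs:
    """Stand-in for AxolotlTrainingArguments, reduced to the relevant fields."""
    relora_steps: Optional[int] = None
    relora_warmup_steps: Optional[int] = None
    relora_anneal_steps: Optional[int] = None
    relora_prune_ratio: Optional[float] = 0.9  # default from the new field

cfg = {"relora_steps": 150, "relora_warmup_steps": 10, "relora_prune_ratio": 0.8}

training_arguments_kwargs = {}
if cfg.get("relora_steps"):
    training_arguments_kwargs["relora_steps"] = cfg["relora_steps"]
    training_arguments_kwargs["relora_warmup_steps"] = cfg.get("relora_warmup_steps")
    if cfg.get("relora_anneal_steps"):
        training_arguments_kwargs["relora_anneal_steps"] = cfg["relora_anneal_steps"]
    # Truthiness gate, as in the diff: an explicit 0.0 would be skipped here,
    # so the dataclass default of 0.9 would silently apply instead.
    if cfg.get("relora_prune_ratio"):
        training_arguments_kwargs["relora_prune_ratio"] = cfg["relora_prune_ratio"]

args = SketchTrainingArgs(**training_arguments_kwargs)
assert args.relora_prune_ratio == 0.8
```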
    	
src/axolotl/monkeypatch/relora.py

```diff
@@ -46,8 +46,9 @@ def reset_optimizer(
     *,
     reset_params: list[str],  # where str is the key to a torch.nn.Parameter
     optimizer_state_keys: list[str],
+    prune_ratio: float = 0.9,
 ):
-    pruning_fn = partial(magnitude_pruning_, prune_ratio=0.9)
+    pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
     n_zeros = 0
     n_total = 0
 
@@ -159,6 +160,7 @@ class ReLoRACallback(TrainerCallback):
                     optimizer,
                     reset_params=lora_params,
                     optimizer_state_keys=optimizer_state_keys,
+                    prune_ratio=args.relora_prune_ratio,
                 )
 
             if self.quantized:
```
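The `magnitude_pruning_` helper itself is outside this hunk. As context for what `prune_ratio` controls, here is a hedged sketch of an in-place magnitude-pruning function consistent with the `partial(magnitude_pruning_, prune_ratio=prune_ratio)` call above; it is an assumption about the shape of the helper, not necessarily axolotl's exact implementation:

```python
import torch

def magnitude_pruning_(tensor: torch.Tensor, prune_ratio: float) -> None:
    """Zero out, in place, roughly the smallest-magnitude `prune_ratio`
    fraction of `tensor`, keeping only the largest entries."""
    magnitude = torch.abs(tensor)
    # torch.quantile expects floating-point input; find the cutoff below
    # which ~prune_ratio of the magnitudes fall.
    threshold = torch.quantile(magnitude.flatten().float(), prune_ratio).to(tensor.dtype)
    mask = magnitude > threshold
    tensor.mul_(mask.to(dtype=tensor.dtype))

# With the preserved default of 0.9, roughly 90% of each optimizer-state
# tensor is zeroed at every ReLoRA restart.
state = torch.randn(10_000)
magnitude_pruning_(state, prune_ratio=0.9)
print(f"zeroed: {(state == 0).float().mean().item():.2f}")  # ~0.90
```

Exposing `relora_prune_ratio` through the config therefore lets users control how much optimizer momentum survives each restart, rather than hardcoding the 0.9 used previously.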
