|  | """Implements the ReLoRA training procedure from https://arxiv.org/abs/2307.05695, minus the initial full fine-tune.""" | 
					
						
						|  | import glob | 
					
						
						|  | import json | 
					
						
						|  | import logging | 
					
						
						|  | import os.path | 
					
						
						|  | import shutil | 
					
						
						|  | from pathlib import Path | 
					
						
						|  | from typing import Dict, List, Sequence | 
					
						
						|  |  | 
					
						
						|  | import bitsandbytes as bnb | 
					
						
						|  | import peft | 
					
						
						|  | import safetensors.torch as st | 
					
						
						|  | import torch | 
					
						
						|  | from huggingface_hub import snapshot_download | 
					
						
						|  | from torch.optim.lr_scheduler import LRScheduler | 
					
						
						|  | from torch.optim.optimizer import Optimizer | 
					
						
						|  | from transformers import ( | 
					
						
						|  | TrainerCallback, | 
					
						
						|  | TrainerControl, | 
					
						
						|  | TrainerState, | 
					
						
						|  | TrainingArguments, | 
					
						
						|  | ) | 
					
						
						|  | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR | 
					
						
						|  |  | 
					
						
						|  | from axolotl.utils.dict import DictDefault | 
					
						
						|  | from axolotl.utils.distributed import is_main_process | 
					
						
						|  |  | 
					
						
						|  | LOG = logging.getLogger("axolotl.relora") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def reset_optimizer(optimizer: torch.optim.Optimizer): | 
					
						
						|  | for group in optimizer.param_groups: | 
					
						
						|  | for param in group["params"]: | 
					
						
						|  | param_state = optimizer.state[param] | 
					
						
						|  | for key in param_state: | 
					
						
						|  | if "qmap" in key: | 
					
						
						|  | continue | 
					
						
						|  |  | 
					
						
						|  | if key == "step" and isinstance(param_state[key], int): | 
					
						
						|  | param_state[key] = 0 | 
					
						
						|  | else: | 
					
						
						|  | param_state[key] = torch.zeros_like(param_state[key]) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class ReLoRACallback(TrainerCallback): | 
					
						
						|  | """Callback to merge LoRA weights into the base model and save full-weight checkpoints""" | 
					
						
						|  |  | 
					
						
						|  | def __init__(self, cfg: DictDefault): | 
					
						
						|  | self.relora_steps = cfg.relora_steps | 
					
						
						|  | self.cpu_offload = cfg.relora_cpu_offload | 
					
						
						|  | self.quantized = cfg.load_in_4bit or cfg.load_in_8bit | 
					
						
						|  | self.last_full_model = cfg.base_model | 
					
						
						|  | self.resume_from_checkpoint = cfg.resume_from_checkpoint | 
					
						
						|  |  | 
					
						
						|  | if not os.path.exists(self.last_full_model): | 
					
						
						|  | self.last_full_model = str(Path(snapshot_download(cfg.base_model))) | 
					
						
						|  |  | 
					
						
						|  | assert os.path.exists( | 
					
						
						|  | self.last_full_model | 
					
						
						|  | ), "for ReLORA base_model must be a local path" | 
					
						
						|  |  | 
					
						
						|  | self.num_lora_restarts = 0 | 
					
						
						|  | self.need_full_save = False | 
					
						
						|  |  | 
					
						
						|  | def on_train_begin( | 
					
						
						|  | self, | 
					
						
						|  | _args: TrainingArguments, | 
					
						
						|  | _state: TrainerState, | 
					
						
						|  | control: TrainerControl, | 
					
						
						|  | model: peft.LoraModel, | 
					
						
						|  | **_kwargs, | 
					
						
						|  | ): | 
					
						
						|  | if self.resume_from_checkpoint: | 
					
						
						|  | weight_path = os.path.join(self.resume_from_checkpoint, "relora") | 
					
						
						|  | if not os.path.exists(weight_path): | 
					
						
						|  | LOG.warning( | 
					
						
						|  | "Resuming ReLoRA from checkpoint, but no full-weight save found" | 
					
						
						|  | ) | 
					
						
						|  | else: | 
					
						
						|  | LOG.info(f"Loading adjusted base weights from {weight_path}") | 
					
						
						|  | load_weight_checkpoint(model, weight_path) | 
					
						
						|  | return control | 
					
						
						|  |  | 
					
						
						|  | def on_step_begin( | 
					
						
						|  | self, | 
					
						
						|  | args: TrainingArguments, | 
					
						
						|  | state: TrainerState, | 
					
						
						|  | control: TrainerControl, | 
					
						
						|  | model: peft.LoraModel, | 
					
						
						|  | optimizer: torch.optim.Optimizer, | 
					
						
						|  | **_kwargs, | 
					
						
						|  | ): | 
					
						
						|  | if state.global_step > 0 and state.global_step % self.relora_steps == 0: | 
					
						
						|  | checkpoint_folder = os.path.join( | 
					
						
						|  | args.output_dir, | 
					
						
						|  | f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", | 
					
						
						|  | "relora", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | with torch.no_grad(): | 
					
						
						|  | merge_and_save( | 
					
						
						|  | model, | 
					
						
						|  | self.last_full_model, | 
					
						
						|  | checkpoint_folder, | 
					
						
						|  | reinit=True, | 
					
						
						|  | quantized=self.quantized, | 
					
						
						|  | actually_save=is_main_process(), | 
					
						
						|  | cpu_offload=self.cpu_offload, | 
					
						
						|  | ) | 
					
						
						|  | reset_optimizer(optimizer) | 
					
						
						|  |  | 
					
						
						|  | if self.quantized: | 
					
						
						|  | self.last_full_model = checkpoint_folder | 
					
						
						|  | self.num_lora_restarts += 1 | 
					
						
						|  |  | 
					
						
						|  | return control | 
					
						
						|  |  | 
					
						
						|  | def on_save( | 
					
						
						|  | self, | 
					
						
						|  | args: TrainingArguments, | 
					
						
						|  | state: TrainerState, | 
					
						
						|  | control: TrainerControl, | 
					
						
						|  | model: peft.LoraModel, | 
					
						
						|  | **_kwargs, | 
					
						
						|  | ): | 
					
						
						|  | checkpoint_folder = os.path.join( | 
					
						
						|  | args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", "relora" | 
					
						
						|  | ) | 
					
						
						|  | if ( | 
					
						
						|  | state.global_step >= self.relora_steps | 
					
						
						|  | and state.global_step % self.relora_steps != 0 | 
					
						
						|  | ): | 
					
						
						|  | if self.quantized: | 
					
						
						|  | if is_main_process() and self.last_full_model != checkpoint_folder: | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | LOG.info(f"moving last full parameter save to {checkpoint_folder}") | 
					
						
						|  | os.makedirs(checkpoint_folder, exist_ok=True) | 
					
						
						|  | chunks = glob.glob( | 
					
						
						|  | f"{self.last_full_model}/model*.safetensors" | 
					
						
						|  | ) + glob.glob(f"{self.last_full_model}/model*.index.json") | 
					
						
						|  | for path in chunks: | 
					
						
						|  | new_path = os.path.abspath(shutil.move(path, checkpoint_folder)) | 
					
						
						|  | try: | 
					
						
						|  | os.symlink(new_path, path) | 
					
						
						|  | except OSError: | 
					
						
						|  |  | 
					
						
						|  | pass | 
					
						
						|  |  | 
					
						
						|  | self.last_full_model = checkpoint_folder | 
					
						
						|  | else: | 
					
						
						|  | model.model.save_pretrained(checkpoint_folder, safe_serialization=True) | 
					
						
						|  |  | 
					
						
						|  | return control | 
					
						
						|  |  | 
					
						
						|  | def on_log( | 
					
						
						|  | self, | 
					
						
						|  | _args: TrainingArguments, | 
					
						
						|  | _state: TrainerState, | 
					
						
						|  | control: TrainerControl, | 
					
						
						|  | logs: Dict[str, float], | 
					
						
						|  | **_kwargs, | 
					
						
						|  | ): | 
					
						
						|  | logs["num_lora_restarts"] = self.num_lora_restarts | 
					
						
						|  | return control | 
					
						
						|  |  | 
					
						
						|  | def on_train_end( | 
					
						
						|  | self, | 
					
						
						|  | args: TrainingArguments, | 
					
						
						|  | _state: TrainerState, | 
					
						
						|  | control: TrainerControl, | 
					
						
						|  | model: peft.LoraModel, | 
					
						
						|  | **_kwargs, | 
					
						
						|  | ): | 
					
						
						|  | if self.quantized: | 
					
						
						|  |  | 
					
						
						|  | with torch.no_grad(): | 
					
						
						|  | merge_and_save( | 
					
						
						|  | model, | 
					
						
						|  | self.last_full_model, | 
					
						
						|  | args.output_dir, | 
					
						
						|  | reinit=False, | 
					
						
						|  | quantized=self.quantized, | 
					
						
						|  | actually_save=is_main_process(), | 
					
						
						|  | cpu_offload=self.cpu_offload, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | return control | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class ReLoRAScheduler(LRScheduler): | 
					
						
						|  | """Wraps another scheduler to apply per-lora-restart learning rate warmups.""" | 
					
						
						|  |  | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | optimizer: Optimizer, | 
					
						
						|  | inner_schedule: LRScheduler, | 
					
						
						|  | relora_steps: int, | 
					
						
						|  | warmup_steps: int, | 
					
						
						|  | min_lr_scale: float = 0.001, | 
					
						
						|  | ) -> None: | 
					
						
						|  | self.inner_schedule = inner_schedule | 
					
						
						|  | self.relora_steps = relora_steps | 
					
						
						|  | self.warmup_steps = warmup_steps | 
					
						
						|  | self.min_lr_scale = min_lr_scale | 
					
						
						|  | super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose) | 
					
						
						|  |  | 
					
						
						|  | def get_lr(self) -> float: | 
					
						
						|  | self.inner_schedule.last_epoch = self.last_epoch | 
					
						
						|  |  | 
					
						
						|  | original = self.inner_schedule.get_lr() | 
					
						
						|  | step = self.last_epoch | 
					
						
						|  | if step < self.relora_steps: | 
					
						
						|  | scale = 1 | 
					
						
						|  | else: | 
					
						
						|  | cycle_t = min(1.0, (step % self.relora_steps) / self.warmup_steps) | 
					
						
						|  | scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale | 
					
						
						|  |  | 
					
						
						|  | if isinstance(original, Sequence): | 
					
						
						|  | return [lr * scale for lr in original] | 
					
						
						|  | return original * scale | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]: | 
					
						
						|  | model_name = "model.safetensors" | 
					
						
						|  | if not os.path.exists(str(Path(path) / model_name)) and not os.path.exists( | 
					
						
						|  | str(Path(path) / f"{model_name}.index.json") | 
					
						
						|  | ): | 
					
						
						|  | model_name = "pytorch_model.bin" | 
					
						
						|  |  | 
					
						
						|  | index_path = str(Path(path) / f"{model_name}.index.json") | 
					
						
						|  | if os.path.exists(index_path): | 
					
						
						|  | with open(index_path, "r", encoding="utf-8") as file: | 
					
						
						|  | data = json.load(file) | 
					
						
						|  | return data["weight_map"] | 
					
						
						|  | return {(module_name + ".weight"): model_name for module_name in module_names} | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor: | 
					
						
						|  | if isinstance(layer, (peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)): | 
					
						
						|  | adapter = layer.active_adapter | 
					
						
						|  | return ( | 
					
						
						|  | peft.utils.transpose( | 
					
						
						|  | layer.lora_B[adapter].weight.detach().to(device) | 
					
						
						|  | @ layer.lora_A[adapter].weight.detach().to(device), | 
					
						
						|  | getattr(layer, "fan_in_fan_out", False), | 
					
						
						|  | ) | 
					
						
						|  | * layer.scaling[adapter] | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | return layer.get_delta_weight().to(device) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def find_lora_modules(model: peft.LoraModel) -> Dict[str, peft.tuners.lora.LoraLayer]: | 
					
						
						|  | modules: Dict[str, peft.tuners.lora.LoraLayer] = {} | 
					
						
						|  |  | 
					
						
						|  | key_list = [key for key, _ in model.model.named_modules() if "lora" not in key] | 
					
						
						|  | for key in key_list: | 
					
						
						|  | try: | 
					
						
						|  |  | 
					
						
						|  | _parent, target, _target_name = peft.utils._get_submodules(model.model, key) | 
					
						
						|  | except AttributeError: | 
					
						
						|  | continue | 
					
						
						|  |  | 
					
						
						|  | if isinstance(target, peft.tuners.lora.LoraLayer): | 
					
						
						|  | modules[key] = target | 
					
						
						|  |  | 
					
						
						|  | return modules | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def update_weights( | 
					
						
						|  | target: peft.tuners.lora.LoraLayer, new_weight: torch.Tensor, reinit: bool, device | 
					
						
						|  | ): | 
					
						
						|  | if reinit: | 
					
						
						|  | for adapter_name in target.lora_A: | 
					
						
						|  | target.reset_lora_parameters(adapter_name) | 
					
						
						|  | for adapter_name in target.lora_embedding_A: | 
					
						
						|  | target.reset_lora_parameters(adapter_name) | 
					
						
						|  |  | 
					
						
						|  | if isinstance(target, peft.tuners.lora.Linear4bit): | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | target.weight.quant_state = None | 
					
						
						|  | target.weight.data = new_weight.cpu() | 
					
						
						|  | target.to(device) | 
					
						
						|  | elif isinstance(target, peft.tuners.lora.Linear8bitLt): | 
					
						
						|  | target.weight = bnb.nn.Int8Params(new_weight, requires_grad=False).to(device) | 
					
						
						|  | else: | 
					
						
						|  | target.weight.data = new_weight.to(device) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def merge_and_save( | 
					
						
						|  | model: peft.LoraModel, | 
					
						
						|  | model_src: str, | 
					
						
						|  | model_dst: str, | 
					
						
						|  | reinit: bool = False, | 
					
						
						|  | quantized: bool = False, | 
					
						
						|  | cpu_offload: bool = False, | 
					
						
						|  | actually_save: bool = True, | 
					
						
						|  | ): | 
					
						
						|  | modules = find_lora_modules(model) | 
					
						
						|  |  | 
					
						
						|  | if not quantized: | 
					
						
						|  | for module_name, target in modules.items(): | 
					
						
						|  | update = target.get_delta_weight(target.active_adapter).detach() | 
					
						
						|  | target.weight.data += update | 
					
						
						|  |  | 
					
						
						|  | if reinit: | 
					
						
						|  | for adapter_name in target.lora_A: | 
					
						
						|  | target.reset_lora_parameters(adapter_name) | 
					
						
						|  | for adapter_name in target.lora_embedding_A: | 
					
						
						|  | target.reset_lora_parameters(adapter_name) | 
					
						
						|  | return | 
					
						
						|  |  | 
					
						
						|  | os.makedirs(model_dst, exist_ok=True) | 
					
						
						|  | shard_paths = sharded_paths(model_src, modules.keys()) | 
					
						
						|  | out_shard_paths = {} | 
					
						
						|  |  | 
					
						
						|  | unique_shards = list(set(shard_paths.values())) | 
					
						
						|  | for shard_path in unique_shards: | 
					
						
						|  | out_tensors = {} | 
					
						
						|  | if shard_path.endswith(".safetensors"): | 
					
						
						|  | in_tensors = st.load_file(str(Path(model_src) / shard_path)) | 
					
						
						|  | else: | 
					
						
						|  | in_tensors = torch.load(Path(model_src) / shard_path) | 
					
						
						|  | if "state_dict" in in_tensors: | 
					
						
						|  | in_tensors = in_tensors["state_dict"] | 
					
						
						|  |  | 
					
						
						|  | for module_name, target in modules.items(): | 
					
						
						|  | key = module_name + ".weight" | 
					
						
						|  | if key not in shard_paths or shard_paths[key] != shard_path: | 
					
						
						|  | continue | 
					
						
						|  |  | 
					
						
						|  | orig_weight = in_tensors[key] | 
					
						
						|  | old_dev = target.weight.device | 
					
						
						|  | math_dev = "cpu" if cpu_offload else old_dev | 
					
						
						|  |  | 
					
						
						|  | delta_weight = lora_delta_weight(target, math_dev) | 
					
						
						|  | new_weight = orig_weight.to(math_dev) + delta_weight | 
					
						
						|  | del delta_weight | 
					
						
						|  |  | 
					
						
						|  | if actually_save: | 
					
						
						|  | out_tensors[key] = new_weight.half().cpu() | 
					
						
						|  |  | 
					
						
						|  | update_weights(target, new_weight, reinit=reinit, device=old_dev) | 
					
						
						|  |  | 
					
						
						|  | if actually_save: | 
					
						
						|  | out_shard_name = shard_path | 
					
						
						|  | if out_shard_name.startswith("pytorch_model"): | 
					
						
						|  | out_shard_name = ( | 
					
						
						|  | out_shard_name.replace("pytorch_model", "model").rstrip(".bin") | 
					
						
						|  | + ".safetensors" | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | for module_name in in_tensors: | 
					
						
						|  | if module_name not in out_tensors: | 
					
						
						|  | out_tensors[module_name] = in_tensors[module_name].half() | 
					
						
						|  | out_shard_paths[module_name] = out_shard_name | 
					
						
						|  |  | 
					
						
						|  | shard_fn = str(Path(model_dst) / out_shard_name) | 
					
						
						|  | LOG.info(f"saving tensors to {shard_fn}") | 
					
						
						|  | st.save_file(out_tensors, shard_fn, metadata={"format": "pt"}) | 
					
						
						|  |  | 
					
						
						|  | del in_tensors | 
					
						
						|  | del out_tensors | 
					
						
						|  | torch.cuda.empty_cache() | 
					
						
						|  |  | 
					
						
						|  | if actually_save and len(unique_shards) > 1: | 
					
						
						|  | with open( | 
					
						
						|  | str(Path(model_dst, "model.safetensors.index.json")), "w", encoding="utf-8" | 
					
						
						|  | ) as file: | 
					
						
						|  | json.dump({"metadata": {}, "weight_map": out_shard_paths}, file) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def load_weight_checkpoint(model: peft.LoraModel, checkpoint_path: str): | 
					
						
						|  | modules = find_lora_modules(model) | 
					
						
						|  | shard_paths = sharded_paths(checkpoint_path, modules.keys()) | 
					
						
						|  | unique_shards = list(set(shard_paths.values())) | 
					
						
						|  |  | 
					
						
						|  | for shard_path in unique_shards: | 
					
						
						|  | tensors = st.load_file(os.path.join(checkpoint_path, shard_path)) | 
					
						
						|  |  | 
					
						
						|  | for module_name, target in modules.items(): | 
					
						
						|  | key = module_name + ".weight" | 
					
						
						|  | if key not in shard_paths or shard_paths[key] != shard_path: | 
					
						
						|  | continue | 
					
						
						|  |  | 
					
						
						|  | new_weight = tensors[key] | 
					
						
						|  | update_weights( | 
					
						
						|  | target, new_weight, reinit=False, device=target.weight.device | 
					
						
						|  | ) | 
					
						
						|  |  |