# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Callable, Dict, List, Optional, Union

import torch
from huggingface_hub.utils import validate_hf_hub_args

from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_peft_available,
    is_peft_version,
    is_torch_version,
    is_transformers_available,
    is_transformers_version,
    logging,
)

from diffusers.loaders.lora_base import (  # noqa
    LoraBaseMixin,
    _fetch_state_dict,
)
from diffusers.loaders.lora_conversion_utils import (
    _convert_non_diffusers_lumina2_lora_to_diffusers,
)


# `low_cpu_mem_usage` is only enabled by default when torch, peft and transformers are all
# recent enough; otherwise callers have to opt in explicitly.
_LOW_CPU_MEM_USAGE_DEFAULT_LORA = False
if is_torch_version(">=", "1.9.0"):
    if (
        is_peft_available()
        and is_peft_version(">=", "0.13.1")
        and is_transformers_available()
        and is_transformers_version(">", "4.45.2")
    ):
        _LOW_CPU_MEM_USAGE_DEFAULT_LORA = True


logger = logging.get_logger(__name__)

TRANSFORMER_NAME = "transformer"


class OmniGen2LoraLoaderMixin(LoraBaseMixin):
    r"""
    Load LoRA layers into [`OmniGen2Transformer2DModel`]. Specific to [`OmniGen2Pipeline`].
    """

    _lora_loadable_modules = ["transformer"]
    transformer_name = TRANSFORMER_NAME

    @classmethod
    @validate_hf_hub_args
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
        r"""
        Fetch a LoRA checkpoint and return its state dict.

        Keys containing `dora_scale` are dropped (with a warning), and checkpoints whose keys start with
        `diffusion_model.` are converted with `_convert_non_diffusers_lumina2_lora_to_diffusers`.

        This function is experimental and might change in the future.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      with [`ModelMixin.save_pretrained`].
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            weight_name (`str`, *optional*):
                Name of the weight file to load; forwarded to `_fetch_state_dict`.
            use_safetensors (`bool`, *optional*):
                Forwarded to `_fetch_state_dict`. When left unset it is treated as `True` and `allow_pickle=True` is
                passed alongside it; when set explicitly, `allow_pickle=False` is passed.

        Returns:
            `dict`: The (possibly filtered and converted) LoRA state dict.
        """
        # Pull the Hub / file-resolution options out of kwargs before fetching the LoRA state dict.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

        # Pickle is only allowed when the caller did not state a preference about safetensors.
        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        state_dict = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            allow_pickle=allow_pickle,
        )

        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
        if is_dora_scale_present:
            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
            logger.warning(warn_msg)
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

        # Checkpoints with `diffusion_model.`-prefixed keys are not in the diffusers layout: convert them.
        non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict)
        if non_diffusers:
            state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict)

        return state_dict

    # Adapted from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
    def load_lora_weights(
        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
    ):
        """
        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into the pipeline's transformer.

        `low_cpu_mem_usage` and `hotswap` are consumed here; all remaining kwargs are forwarded to
        `self.lora_state_dict`. The resulting state dict is then handed to `self.load_lora_into_transformer`.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                See `lora_state_dict`.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
            hotswap (`bool`, *optional*, defaults to `False`):
                Forwarded to `load_lora_into_transformer`; see that method for details.
            kwargs (`dict`, *optional*):
                See `lora_state_dict`.

        Raises:
            ValueError: If the PEFT backend is unavailable, if `low_cpu_mem_usage=True` is requested with
                `peft<0.13.0`, or if any key of the loaded state dict does not contain `"lora"`.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # `hotswap` must be popped here: `lora_state_dict` does not consume it, so leaving it in
        # `kwargs` would silently drop the request instead of reaching `load_lora_into_transformer`.
        hotswap = kwargs.pop("hotswap", False)

        # if a dict is passed, copy it instead of modifying it inplace
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        is_correct_format = all("lora" in key for key in state_dict)
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        # Prefer a literal `transformer` attribute; otherwise resolve it through `transformer_name`.
        transformer = self.transformer if hasattr(self, "transformer") else getattr(self, self.transformer_name)

        self.load_lora_into_transformer(
            state_dict,
            transformer=transformer,
            adapter_name=adapter_name,
            _pipeline=self,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel
    def load_lora_into_transformer(
        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
    ):
        """
        This will load the LoRA layers specified in `state_dict` into `transformer`.

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the LoRA layer parameters.
            transformer (`OmniGen2Transformer2DModel`):
                The Transformer model to load the LoRA layers into.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
            hotswap : (`bool`, *optional*)
                Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
                in-place. This means that, instead of loading an additional adapter, this will take the existing
                adapter weights and replace them with the weights of the new adapter. This can be faster and more
                memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
                torch.compile, loading the new adapter does not require recompilation of the model. When using
                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.

                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
                to call an additional method before loading the adapter:

                ```py
                pipeline = ...  # load diffusers pipeline
                max_rank = ...  # the highest rank among all LoRAs that you want to load
                # call *before* compiling and loading the LoRA adapter
                pipeline.enable_lora_hotswap(target_rank=max_rank)
                pipeline.load_lora_weights(file_name)
                # optionally compile the model now
                ```

                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
                limitations to this technique, which are documented here:
                https://huggingface.co/docs/peft/main/en/package_reference/hotswap

        Raises:
            ValueError: If `low_cpu_mem_usage=True` is requested with `peft<0.13.0`.
        """
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # Load the layers corresponding to transformer.
        logger.info(f"Loading {cls.transformer_name}.")
        transformer.load_lora_adapter(
            state_dict,
            network_alphas=None,
            adapter_name=adapter_name,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )

    @classmethod
    # Adapted from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
    ):
        r"""
        Save the LoRA parameters corresponding to the transformer.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save LoRA parameters to. Will be created if it doesn't exist.
            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `transformer`. Required despite the `None` default.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training and you
                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
                process to avoid race conditions.
            weight_name (`str`, *optional*):
                Forwarded to `write_lora_layers`.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful during distributed training when you need to
                replace `torch.save` with another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.

        Raises:
            ValueError: If `transformer_lora_layers` is missing or empty.
        """
        if not transformer_lora_layers:
            raise ValueError("You must pass `transformer_lora_layers`.")

        # Pack the layers under the transformer's name so they can be told apart on load.
        state_dict = {}
        state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

        # Save the model
        cls.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

    # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
    def fuse_lora(
        self,
        components: List[str] = ["transformer"],  # never mutated in this method; kept as-is for signature compatibility
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
        **kwargs,
    ):
        r"""
        Fuses the LoRA parameters into the original parameters of the corresponding blocks.

        This is an experimental API. All arguments are passed straight through to the parent class' `fuse_lora`.

        Args:
            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
            lora_scale (`float`, defaults to 1.0):
                Controls how much to influence the outputs with the LoRA parameters.
            safe_fusing (`bool`, defaults to `False`):
                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
            adapter_names (`List[str]`, *optional*):
                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.

        Example:

        ```py
        pipeline = ...  # a pipeline that uses this mixin
        pipeline.load_lora_weights(lora_path_or_repo_id, adapter_name="my_adapter")
        pipeline.fuse_lora(lora_scale=0.7)
        ```
        """
        super().fuse_lora(
            components=components,
            lora_scale=lora_scale,
            safe_fusing=safe_fusing,
            adapter_names=adapter_names,
            **kwargs,
        )

    # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
        r"""
        Reverses the effect of
        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).

        This is an experimental API.

        Args:
            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
            kwargs (`dict`, *optional*): Forwarded unchanged to the parent class' `unfuse_lora`.
        """
        super().unfuse_lora(components=components, **kwargs)