# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Callable, Dict, List, Optional, Union

import torch
from huggingface_hub.utils import validate_hf_hub_args

from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_peft_available,
    is_peft_version,
    is_torch_version,
    is_transformers_available,
    is_transformers_version,
    logging,
)
from diffusers.loaders.lora_base import (  # noqa
    LoraBaseMixin,
    _fetch_state_dict,
)
from diffusers.loaders.lora_conversion_utils import (
    _convert_non_diffusers_lumina2_lora_to_diffusers,
)

_LOW_CPU_MEM_USAGE_DEFAULT_LORA = False
if is_torch_version(">=", "1.9.0"):
    if (
        is_peft_available()
        and is_peft_version(">=", "0.13.1")
        and is_transformers_available()
        and is_transformers_version(">", "4.45.2")
    ):
        _LOW_CPU_MEM_USAGE_DEFAULT_LORA = True

logger = logging.get_logger(__name__)

TRANSFORMER_NAME = "transformer"


class OmniGen2LoraLoaderMixin(LoraBaseMixin):
    r"""
    Load LoRA layers into [`OmniGen2Transformer2DModel`]. Specific to [`OmniGen2Pipeline`].
    """

    _lora_loadable_modules = ["transformer"]
    transformer_name = TRANSFORMER_NAME

    @classmethod
    @validate_hf_hub_args
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
r""" | |
Return state dict for lora weights and the network alphas. | |
<Tip warning={true}> | |
We support loading A1111 formatted LoRA checkpoints in a limited capacity. | |
This function is experimental and might change in the future. | |
</Tip> | |
Parameters: | |
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): | |
Can be either: | |
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on | |
the Hub. | |
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved | |
with [`ModelMixin.save_pretrained`]. | |
- A [torch state | |
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). | |
cache_dir (`Union[str, os.PathLike]`, *optional*): | |
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache | |
is not used. | |
force_download (`bool`, *optional*, defaults to `False`): | |
Whether or not to force the (re-)download of the model weights and configuration files, overriding the | |
cached versions if they exist. | |
proxies (`Dict[str, str]`, *optional*): | |
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', | |
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. | |
local_files_only (`bool`, *optional*, defaults to `False`): | |
Whether to only load local model weights and configuration files or not. If set to `True`, the model | |
won't be downloaded from the Hub. | |
token (`str` or *bool*, *optional*): | |
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from | |
`diffusers-cli login` (stored in `~/.huggingface`) is used. | |
revision (`str`, *optional*, defaults to `"main"`): | |
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier | |
allowed by Git. | |
subfolder (`str`, *optional*, defaults to `""`): | |
The subfolder location of a model file within a larger model repository on the Hub or locally. | |
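
        Example:

        A minimal sketch of fetching a LoRA state dict without loading it into the pipeline; the repository id and
        weight file name below are placeholders, not a published checkpoint.

        ```py
        # hypothetical repo id and file name, shown only to illustrate the call signature
        state_dict = OmniGen2Pipeline.lora_state_dict(
            "your-username/omnigen2-lora", weight_name="pytorch_lora_weights.safetensors"
        )
        print(sorted(state_dict.keys())[:5])  # inspect the (already converted) diffusers-style LoRA keys
        ```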
""" | |
        # Load the main state dict first which has the LoRA layers for either of
        # transformer and text encoder or both.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        state_dict = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            allow_pickle=allow_pickle,
        )

        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
        if is_dora_scale_present:
            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale' from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
            logger.warning(warn_msg)
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

        # Convert checkpoints saved in the original (non-diffusers) key layout to the diffusers format.
        non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict)
        if non_diffusers:
            state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict)

        return state_dict

    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
    def load_lora_weights(
        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
    ):
""" | |
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and | |
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See | |
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. | |
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state | |
dict is loaded into `self.transformer`. | |
Parameters: | |
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): | |
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. | |
adapter_name (`str`, *optional*): | |
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use | |
`default_{i}` where i is the total number of adapters being loaded. | |
low_cpu_mem_usage (`bool`, *optional*): | |
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random | |
weights. | |
kwargs (`dict`, *optional*): | |
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. | |
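
        Example:

        A minimal sketch of loading a LoRA into an OmniGen2 pipeline; the repository ids below are placeholders and
        assume a pipeline class that mixes in `OmniGen2LoraLoaderMixin`.

        ```py
        import torch

        # hypothetical model and LoRA repo ids, shown only to illustrate the workflow
        pipe = OmniGen2Pipeline.from_pretrained("your-org/OmniGen2", torch_dtype=torch.bfloat16).to("cuda")
        pipe.load_lora_weights("your-username/omnigen2-lora", adapter_name="style")
        ```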
""" | |
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # if a dict is passed, copy it instead of modifying it inplace
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        self.load_lora_into_transformer(
            state_dict,
            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
            adapter_name=adapter_name,
            _pipeline=self,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel
    def load_lora_into_transformer(
        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
    ):
""" | |
This will load the LoRA layers specified in `state_dict` into `transformer`. | |
Parameters: | |
state_dict (`dict`): | |
A standard state dict containing the lora layer parameters. The keys can either be indexed directly | |
into the unet or prefixed with an additional `unet` which can be used to distinguish between text | |
encoder lora layers. | |
transformer (`Lumina2Transformer2DModel`): | |
The Transformer model to load the LoRA layers into. | |
adapter_name (`str`, *optional*): | |
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use | |
`default_{i}` where i is the total number of adapters being loaded. | |
low_cpu_mem_usage (`bool`, *optional*): | |
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random | |
weights. | |
hotswap : (`bool`, *optional*) | |
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter | |
in-place. This means that, instead of loading an additional adapter, this will take the existing | |
adapter weights and replace them with the weights of the new adapter. This can be faster and more | |
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with | |
torch.compile, loading the new adapter does not require recompilation of the model. When using | |
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. | |
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need | |
to call an additional method before loading the adapter: | |
```py | |
pipeline = ... # load diffusers pipeline | |
max_rank = ... # the highest rank among all LoRAs that you want to load | |
# call *before* compiling and loading the LoRA adapter | |
pipeline.enable_lora_hotswap(target_rank=max_rank) | |
pipeline.load_lora_weights(file_name) | |
# optionally compile the model now | |
``` | |
Note that hotswapping adapters of the text encoder is not yet supported. There are some further | |
limitations to this technique, which are documented here: | |
https://huggingface.co/docs/peft/main/en/package_reference/hotswap | |
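
        Example:

        A minimal sketch of calling this method directly on an already constructed pipeline; normally it is invoked
        for you by `load_lora_weights`, and the `state_dict` below is assumed to come from `lora_state_dict`.

        ```py
        # `state_dict` is assumed to be a diffusers-format LoRA state dict, e.g. from `OmniGen2Pipeline.lora_state_dict(...)`
        pipe.load_lora_into_transformer(state_dict, transformer=pipe.transformer, adapter_name="style")
        ```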
""" | |
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # Load the layers corresponding to transformer.
        logger.info(f"Loading {cls.transformer_name}.")
        transformer.load_lora_adapter(
            state_dict,
            network_alphas=None,
            adapter_name=adapter_name,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
    ):
r""" | |
Save the LoRA parameters corresponding to the UNet and text encoder. | |
Arguments: | |
save_directory (`str` or `os.PathLike`): | |
Directory to save LoRA parameters to. Will be created if it doesn't exist. | |
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): | |
State dict of the LoRA layers corresponding to the `transformer`. | |
is_main_process (`bool`, *optional*, defaults to `True`): | |
Whether the process calling this is the main process or not. Useful during distributed training and you | |
need to call this function on all processes. In this case, set `is_main_process=True` only on the main | |
process to avoid race conditions. | |
save_function (`Callable`): | |
The function to use to save the state dictionary. Useful during distributed training when you need to | |
replace `torch.save` with another method. Can be configured with the environment variable | |
`DIFFUSERS_SAVE_MODE`. | |
safe_serialization (`bool`, *optional*, defaults to `True`): | |
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. | |
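
        Example:

        A minimal sketch of saving trained LoRA layers; it assumes `transformer_lora_state_dict` was produced during
        training (for example via `peft.get_peft_model_state_dict` on the transformer) and the output directory is a
        placeholder.

        ```py
        OmniGen2Pipeline.save_lora_weights(
            save_directory="./omnigen2-lora",  # hypothetical output directory
            transformer_lora_layers=transformer_lora_state_dict,
            safe_serialization=True,
        )
        ```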
""" | |
        state_dict = {}

        if not transformer_lora_layers:
            raise ValueError("You must pass `transformer_lora_layers`.")

        if transformer_lora_layers:
            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

        # Save the model
        cls.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

    # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
    def fuse_lora(
        self,
        components: List[str] = ["transformer"],
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
        **kwargs,
    ):
r""" | |
Fuses the LoRA parameters into the original parameters of the corresponding blocks. | |
<Tip warning={true}> | |
This is an experimental API. | |
</Tip> | |
Args: | |
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. | |
lora_scale (`float`, defaults to 1.0): | |
Controls how much to influence the outputs with the LoRA parameters. | |
safe_fusing (`bool`, defaults to `False`): | |
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. | |
adapter_names (`List[str]`, *optional*): | |
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. | |
Example: | |
```py | |
from diffusers import DiffusionPipeline | |
import torch | |
pipeline = DiffusionPipeline.from_pretrained( | |
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 | |
).to("cuda") | |
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") | |
pipeline.fuse_lora(lora_scale=0.7) | |
``` | |
""" | |
        super().fuse_lora(
            components=components,
            lora_scale=lora_scale,
            safe_fusing=safe_fusing,
            adapter_names=adapter_names,
            **kwargs,
        )

    # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
        r"""
        Reverses the effect of
        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the transformer LoRA parameters.
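
        Example:

        A minimal sketch, assuming LoRA weights were previously loaded and fused into the pipeline's transformer:

        ```py
        pipeline.fuse_lora(lora_scale=0.7)
        # ... run inference with the fused weights ...
        pipeline.unfuse_lora()  # restore the original transformer weights
        ```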
""" | |
super().unfuse_lora(components=components, **kwargs) |