How to use LoRA with this model in Python?

#42
by megatrump - opened

Hi, thank you very much for your contribution! I have used this quantized model in ComfyUI and loaded the LoRA there without problems, and I can likewise load the model and a LoRA with stable-diffusion.cpp. However, I ran into an issue when I tried to load the model manually from Python so that I can manage the model lifecycle myself. Specifically, I took the code mentioned in #41 and modified it:

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig

# Load the GGUF-quantized transformer from the single-file checkpoint.
checkpoint = "models/flux1-dev-q4_0.gguf"
transformer = FluxTransformer2DModel.from_single_file(
    checkpoint,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)

# Build the full FLUX pipeline around the quantized transformer.
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)

pipe = pipeline.to("cuda")
pipe.enable_model_cpu_offload()

# This is the call that fails.
pipe.load_lora_weights("models/LoRAs/30.safetensors")

Then, I received the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[7], line 2
      1 with torch.inference_mode():
----> 2     pipe.load_lora_weights("/home/xxx/models/Flux-1.dev-Q4/30.safetensors")

File ~/Services/FluxLoRAQuantitativeQuest/.venv/lib/python3.12/site-packages/diffusers/loaders/lora_pipeline.py:1550, in FluxLoraLoaderMixin.load_lora_weights(self, pretrained_model_name_or_path_or_dict, adapter_name, **kwargs)
   1543 transformer_norm_state_dict = {
   1544     k: state_dict.pop(k)
   1545     for k in list(state_dict.keys())
   1546     if "transformer." in k and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
   1547 }
   1549 transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
-> 1550 has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
   1551     transformer, transformer_lora_state_dict, transformer_norm_state_dict
   1552 )
   1554 if has_param_with_expanded_shape:
   1555     logger.info(
   1556         "The LoRA weights contain parameters that have different shapes that expected by the transformer. "
   1557         "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
   1558         "To get a comprehensive list of parameter names that were modified, enable debug logging."
   1559     )

File ~/Services/FluxLoRAQuantitativeQuest/.venv/lib/python3.12/site-packages/diffusers/loaders/lora_pipeline.py:2020, in FluxLoraLoaderMixin._maybe_expand_transformer_param_shape_or_error_(cls, transformer, lora_state_dict, norm_state_dict, prefix)
   2017 parent_module = transformer.get_submodule(parent_module_name)
   2019 with torch.device("meta"):
-> 2020     expanded_module = torch.nn.Linear(
   2021         in_features, out_features, bias=bias, dtype=module_weight.dtype
   2022     )
   2023 # Only weights are expanded and biases are not. This is because only the input dimensions
   2024 # are changed while the output dimensions remain the same. The shape of the weight tensor
   2025 # is (out_features, in_features), while the shape of bias tensor is (out_features,), which
   2026 # explains the reason why only weights are expanded.
   2027 new_weight = torch.zeros_like(
   2028     expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
   2029 )

File ~/Services/FluxLoRAQuantitativeQuest/.venv/lib/python3.12/site-packages/torch/nn/modules/linear.py:105, in Linear.__init__(self, in_features, out_features, bias, device, dtype)
    103 self.in_features = in_features
    104 self.out_features = out_features
--> 105 self.weight = Parameter(
    106     torch.empty((out_features, in_features), **factory_kwargs)
    107 )
    108 if bias:
    109     self.bias = Parameter(torch.empty(out_features, **factory_kwargs))

File ~/Services/FluxLoRAQuantitativeQuest/.venv/lib/python3.12/site-packages/torch/nn/parameter.py:46, in Parameter.__new__(cls, data, requires_grad)
     42     data = torch.empty(0)
     43 if type(data) is torch.Tensor or type(data) is Parameter:
     44     # For ease of BC maintenance, keep this path for standard Tensor.
     45     # Eventually (tm), we should change the behavior for standard Tensor to match.
---> 46     return torch.Tensor._make_subclass(cls, data, requires_grad)
     48 # Path for custom tensors: set a flag on the instance to indicate parameter-ness.
     49 t = data.detach().requires_grad_(requires_grad)

RuntimeError: Only Tensors of floating point and complex dtype can require gradients
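
From the traceback, the failure seems to come from _maybe_expand_transformer_param_shape_or_error_ creating a torch.nn.Linear with dtype=module_weight.dtype, after which Parameter.__new__ rejects the non-floating-point dtype. The sketch below (run against the transformer object from the snippet above; this is only my guess at a diagnostic, not a fix) should list any quantized parameters that are not floating point:

# Sketch only: list parameters of the GGUF-quantized transformer whose dtype is
# not floating point. Parameter.__new__ in the traceback rejects such dtypes,
# so any of these hit by the LoRA shape-expansion path would raise exactly this
# RuntimeError.
for name, param in transformer.named_parameters():
    if not (param.dtype.is_floating_point or param.dtype.is_complex):
        print(name, param.dtype)
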

How should I load my LoRA? I’m really looking forward to your reply!
