How to use LoRA with this model in Python?
#42 · opened by megatrump
Hi, thank you very much for your contribution! I have successfully used this quantized model in ComfyUI and loaded a LoRA with it, and I can likewise load the model and apply a LoRA with stable-diffusion.cpp. However, I ran into an issue when trying to load the model manually from Python so that I can manage the model lifecycle myself. Specifically, I started from the code mentioned in #41 and modified it as follows:
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig

checkpoint = "models/flux1-dev-q4_0.gguf"

transformer = FluxTransformer2DModel.from_single_file(
    checkpoint,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)

pipe = pipeline.to("cuda")
pipe.enable_model_cpu_offload()

pipe.load_lora_weights("models/LoRAs/30.safetensors")
Then, I received the following error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[7], line 2
1 with torch.inference_mode():
----> 2 pipe.load_lora_weights("/home/xxx/models/Flux-1.dev-Q4/30.safetensors")
File ~/Services/FluxLoRAQuantitativeQuest/.venv/lib/python3.12/site-packages/diffusers/loaders/lora_pipeline.py:1550, in FluxLoraLoaderMixin.load_lora_weights(self, pretrained_model_name_or_path_or_dict, adapter_name, **kwargs)
1543 transformer_norm_state_dict = {
1544 k: state_dict.pop(k)
1545 for k in list(state_dict.keys())
1546 if "transformer." in k and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
1547 }
1549 transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
-> 1550 has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
1551 transformer, transformer_lora_state_dict, transformer_norm_state_dict
1552 )
1554 if has_param_with_expanded_shape:
1555 logger.info(
1556 "The LoRA weights contain parameters that have different shapes that expected by the transformer. "
1557 "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
1558 "To get a comprehensive list of parameter names that were modified, enable debug logging."
1559 )
File ~/Services/FluxLoRAQuantitativeQuest/.venv/lib/python3.12/site-packages/diffusers/loaders/lora_pipeline.py:2020, in FluxLoraLoaderMixin._maybe_expand_transformer_param_shape_or_error_(cls, transformer, lora_state_dict, norm_state_dict, prefix)
2017 parent_module = transformer.get_submodule(parent_module_name)
2019 with torch.device("meta"):
-> 2020 expanded_module = torch.nn.Linear(
2021 in_features, out_features, bias=bias, dtype=module_weight.dtype
2022 )
2023 # Only weights are expanded and biases are not. This is because only the input dimensions
2024 # are changed while the output dimensions remain the same. The shape of the weight tensor
2025 # is (out_features, in_features), while the shape of bias tensor is (out_features,), which
2026 # explains the reason why only weights are expanded.
2027 new_weight = torch.zeros_like(
2028 expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
2029 )
File ~/Services/FluxLoRAQuantitativeQuest/.venv/lib/python3.12/site-packages/torch/nn/modules/linear.py:105, in Linear.__init__(self, in_features, out_features, bias, device, dtype)
103 self.in_features = in_features
104 self.out_features = out_features
--> 105 self.weight = Parameter(
106 torch.empty((out_features, in_features), **factory_kwargs)
107 )
108 if bias:
109 self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
File ~/Services/FluxLoRAQuantitativeQuest/.venv/lib/python3.12/site-packages/torch/nn/parameter.py:46, in Parameter.__new__(cls, data, requires_grad)
42 data = torch.empty(0)
43 if type(data) is torch.Tensor or type(data) is Parameter:
44 # For ease of BC maintenance, keep this path for standard Tensor.
45 # Eventually (tm), we should change the behavior for standard Tensor to match.
---> 46 return torch.Tensor._make_subclass(cls, data, requires_grad)
48 # Path for custom tensors: set a flag on the instance to indicate parameter-ness.
49 t = data.detach().requires_grad_(requires_grad)
RuntimeError: Only Tensors of floating point and complex dtype can require gradients
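From the traceback, the failure seems to reduce to the step where diffusers rebuilds an expanded torch.nn.Linear using the quantized module's dtype. A minimal sketch of just that step (assuming the GGUF weight tensor is stored in an integer dtype such as torch.uint8, which is my guess and not something I have verified) appears to trigger the same RuntimeError:

import torch

# Minimal sketch of the failing step from the traceback above.
# Assumption: the GGUF-quantized weight reports an integer storage dtype
# (e.g. torch.uint8) rather than a floating-point dtype.
quantized_dtype = torch.uint8

with torch.device("meta"):
    # diffusers constructs the expanded Linear with dtype=module_weight.dtype.
    # nn.Linear stores its weight as an nn.Parameter, and Parameters cannot
    # require gradients for non-floating-point dtypes, so this raises
    # "Only Tensors of floating point and complex dtype can require gradients".
    expanded_module = torch.nn.Linear(64, 64, bias=True, dtype=quantized_dtype)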
How should I load my LoRA? I’m really looking forward to your reply!