Spaces:
Runtime error
Runtime error
| from safetensors.torch import load_file | |
| import torch | |
| from tqdm import tqdm | |
# Public API of this module: only the LoRA merge entry point is exported.
__all__ = ["flux_load_lora"]
def is_int(d):
    """Return True if ``int(d)`` succeeds, False otherwise.

    Callers in this module use it to decide whether a module-path component
    should be used as an integer index rather than as an attribute name.
    Note that ``int()`` truncates floats, so ``is_int(1.5)`` is True while
    ``is_int("1.5")`` is False.
    """
    try:
        int(d)
    except (TypeError, ValueError, OverflowError):
        # The exceptions int() raises for unconvertible input
        # (e.g. None, "abc", float("inf")).
        return False
    return True
def _resolve_lora_target(root, name_parts, sep):
    """Return the object under ``root`` addressed by ``name_parts``.

    Parts for which ``is_int`` is true index into the current object
    (``cur[int(part)]``). Other parts are looked up with ``cur.__getattr__``;
    when that lookup fails, the part is kept pending and joined to the next
    part with ``sep`` before retrying. This lets an attribute whose own name
    contains ``sep`` (e.g. ``self_attn`` when the path was split on ``'_'``)
    still be found. A pending name is dropped when an integer part follows it.

    Args:
        root: object to start the walk from.
        name_parts: path components, in order.
        sep: separator the path was split on; used to re-join pending parts.

    Returns:
        The object reached after consuming every part.

    Raises:
        AttributeError: if trailing parts could not be resolved. Without this
            check the walk would return an ancestor module, and the LoRA delta
            would be applied to the wrong weight.
    """
    cur = root
    pending = ''
    for part in name_parts:
        if is_int(part):
            cur = cur[int(part)]
            pending = ''
            continue
        name = f"{pending}{sep}{part}" if pending else part
        try:
            cur = cur.__getattr__(name)
        except AttributeError:
            pending = name
        else:
            pending = ''
    if pending:
        raise AttributeError(
            f"could not resolve {pending!r} under {type(cur).__name__} "
            f"(path: {sep.join(name_parts)!r})"
        )
    return cur


def _merge_lora_delta(target, up_w, down_w, scale):
    """Add ``scale * (up_w x down_w)`` to ``target.weight`` in place.

    The product is formed in float32 with ``torch.einsum``. Axis 1 of
    ``up_w`` is contracted with axis 0 of ``down_w``, and any axes beyond
    the second are treated as batch axes shared by both factors. Scaling is
    also done in float32. Only the final delta is cast to the weight's dtype
    and moved to the weight's device, so a module on a different device than
    the LoRA tensors is still updated correctly.
    """
    ndim = up_w.dim()
    spec = f"{'ijabcdefg'[:ndim]},{'jkabcdefg'[:ndim]}->{'ikabcdefg'[:ndim]}"
    weight = target.weight.data
    delta = torch.einsum(spec, up_w.to(torch.float32), down_w.to(torch.float32))
    delta = (delta * scale).to(device=weight.device, dtype=weight.dtype)
    target.weight.data = weight + delta


def flux_load_lora(self, lora_file, lora_weight=1.0):
    """Merge a LoRA file directly into ``self.transformer`` and ``self.text_encoder``.

    The weights are modified in place (``weight.data`` is replaced by
    ``weight + delta``). Calling this twice with the same file applies the
    LoRA twice.

    ``self`` must provide ``transformer``, ``text_encoder`` and
    ``lora_state_dict``. It is presumably a diffusers Flux pipeline;
    verify against callers.

    Transformer: keys come from ``self.lora_state_dict``. Each
    ``transformer.<dotted.path>.lora_B.weight`` (up) is paired with the
    matching ``.lora_A.weight`` (down). The delta ``up x down`` is added
    scaled by ``lora_weight`` only.

    Text encoder: keys are read from the raw file with ``load_file``. Each
    ``lora_te1_<underscore_path>.lora_up.weight`` is paired with the matching
    ``.lora_down.weight``. When a matching ``.alpha`` entry exists, the delta
    is additionally scaled by ``alpha / rank``, where rank is
    ``up.shape[1]``.

    Args:
        lora_file: path passed to both ``self.lora_state_dict`` and
            ``load_file``.
        lora_weight: multiplier applied to every merged delta.

    Raises:
        KeyError: if an up weight has no matching down weight.
        AttributeError: if a key's module path cannot be resolved.
    """
    device = self.transformer.device

    # --- DiT (transformer) part ---
    # NOTE(review): network_alphas is requested but never applied, so the
    # transformer delta is merged without alpha scaling. Confirm whether
    # lora_state_dict already folds alphas into the weights for the LoRA
    # formats this is used with.
    state_dict, network_alphas = self.lora_state_dict(lora_file, return_alphas=True)
    prefix, up_suffix, down_suffix = 'transformer.', '.lora_B.weight', '.lora_A.weight'
    # Move only the tensors this section actually uses.
    state_dict = {k: v.to(device) for k, v in state_dict.items() if k.startswith(prefix)}
    # Skip everything that is not an up weight (lora_A, alpha, biases, ...).
    up_keys = [k for k in state_dict if k.endswith(up_suffix)]
    for k_up in tqdm(up_keys, total=len(up_keys), desc="loading lora in transformer ..."):
        base = k_up[:-len(up_suffix)]
        target = _resolve_lora_target(self.transformer, base[len(prefix):].split('.'), '.')
        down_w = state_dict[base + down_suffix]
        _merge_lora_delta(target, state_dict[k_up], down_w, lora_weight)

    # --- text encoder part: 'lora_te1_' keys read straight from the file ---
    prefix, up_suffix, down_suffix = 'lora_te1_', '.lora_up.weight', '.lora_down.weight'
    te_state_dict = {
        k: v.to(device) for k, v in load_file(lora_file).items() if k.startswith(prefix)
    }
    # Skip everything that is not an up weight (lora_down, alpha, ...).
    up_keys = [k for k in te_state_dict if k.endswith(up_suffix)]
    for k_up in tqdm(up_keys, total=len(up_keys), desc="loading lora in text_encoder ..."):
        base = k_up[:-len(up_suffix)]
        # The module path is underscore-separated, so attribute names that
        # contain '_' are rebuilt by _resolve_lora_target.
        target = _resolve_lora_target(self.text_encoder, base[len(prefix):].split('_'), '_')
        up_w = te_state_dict[k_up]
        down_w = te_state_dict[base + down_suffix]
        alpha = te_state_dict.get(base + '.alpha')
        lora_scale = 1.0 if alpha is None else float(alpha) / up_w.shape[1]
        _merge_lora_delta(target, up_w, down_w, lora_scale * lora_weight)