from diffusers.models.attention_processor import FluxAttnProcessor2_0
from safetensors import safe_open
import re
import torch

from .layers_cache import MultiDoubleStreamBlockLoraProcessor, MultiSingleStreamBlockLoraProcessor

# The global `device = "cuda"` has been removed; the target device is now passed in as a parameter.


def load_safetensors(path):
    """Load all tensors from a .safetensors file onto the CPU."""
    tensors = {}
    with safe_open(path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key)
    return tensors


def get_lora_rank(checkpoint):
    """Infer the LoRA rank from the first `.down.weight` tensor in the checkpoint."""
    for k in checkpoint.keys():
        if k.endswith(".down.weight"):
            return checkpoint[k].shape[0]


def load_checkpoint(local_path):
    """Load a LoRA checkpoint (.safetensors or torch-pickled) onto the CPU."""
    checkpoint = None
    if local_path is not None:
        if '.safetensors' in local_path:
            print(f"Loading .safetensors checkpoint from {local_path}")
            checkpoint = load_safetensors(local_path)
        else:
            print(f"Loading checkpoint from {local_path}")
            checkpoint = torch.load(local_path, map_location='cpu')
    return checkpoint


def update_model_with_lora(checkpoint, lora_weights, transformer, cond_size, device="cpu"):
    """Replace the transformer's attention processors with multi-LoRA processors and
    load the weights for each LoRA branch from a single checkpoint."""
    number = len(lora_weights)
    ranks = [get_lora_rank(checkpoint) for _ in range(number)]
    lora_attn_procs = {}
    double_blocks_idx = list(range(19))
    single_blocks_idx = list(range(38))

    for name, attn_processor in transformer.attn_processors.items():
        match = re.search(r'\.(\d+)\.', name)
        if match:
            layer_index = int(match.group(1))

            if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
                # Collect the checkpoint entries belonging to this double-stream block.
                lora_state_dicts = {}
                for key, value in checkpoint.items():
                    if re.search(r'\.(\d+)\.', key):
                        checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
                        if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"):
                            lora_state_dicts[key] = value

                lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
                    dim=3072,
                    ranks=ranks,
                    network_alphas=ranks,
                    lora_weights=lora_weights,
                    device=device,  # use the device passed in by the caller
                    dtype=torch.bfloat16,
                    cond_width=cond_size,
                    cond_height=cond_size,
                    n_loras=number
                )

                # Load weights and move to the specified device
                for n in range(number):
                    lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
                    lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
                    lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
                    lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
                    lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
                    lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
                    lora_attn_procs[name].proj_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.down.weight', None)
                    lora_attn_procs[name].proj_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.up.weight', None)
                lora_attn_procs[name].to(device)  # move to the caller-supplied device

            elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
                # Collect the checkpoint entries belonging to this single-stream block.
                lora_state_dicts = {}
                for key, value in checkpoint.items():
                    if re.search(r'\.(\d+)\.', key):
                        checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
                        if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
                            lora_state_dicts[key] = value

                lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
                    dim=3072,
                    ranks=ranks,
                    network_alphas=ranks,
                    lora_weights=lora_weights,
                    device=device,  # use the device passed in by the caller
                    dtype=torch.bfloat16,
                    cond_width=cond_size,
                    cond_height=cond_size,
                    n_loras=number
                )

                # Load weights and move to the specified device
                for n in range(number):
                    lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
                    lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
                    lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
                    lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
                    lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
                    lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
                lora_attn_procs[name].to(device)  # move to the caller-supplied device

            else:
                lora_attn_procs[name] = FluxAttnProcessor2_0()

    transformer.set_attn_processor(lora_attn_procs)


def update_model_with_multi_lora(checkpoints, lora_weights, transformer, cond_size, device="cpu"):
    """Same as `update_model_with_lora`, but merges several checkpoints, each
    contributing its own LoRA branches. Also updated to take the device as a parameter."""
    ck_number = len(checkpoints)
    cond_lora_number = [len(ls) for ls in lora_weights]
    cond_number = sum(cond_lora_number)
    ranks = [get_lora_rank(checkpoint) for checkpoint in checkpoints]
    multi_lora_weight = []
    for ls in lora_weights:
        for n in ls:
            multi_lora_weight.append(n)

    lora_attn_procs = {}
    double_blocks_idx = list(range(19))
    single_blocks_idx = list(range(38))

    for name, attn_processor in transformer.attn_processors.items():
        match = re.search(r'\.(\d+)\.', name)
        if match:
            layer_index = int(match.group(1))

            if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
                # Gather each checkpoint's entries for this double-stream block.
                lora_state_dicts = [{} for _ in range(ck_number)]
                for idx, checkpoint in enumerate(checkpoints):
                    for key, value in checkpoint.items():
                        if re.search(r'\.(\d+)\.', key):
                            checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
                            if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"):
                                lora_state_dicts[idx][key] = value

                lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
                    dim=3072,
                    ranks=ranks,
                    network_alphas=ranks,
                    lora_weights=multi_lora_weight,
                    device=device,  # use the device passed in by the caller
                    dtype=torch.bfloat16,
                    cond_width=cond_size,
                    cond_height=cond_size,
                    n_loras=cond_number
                )

                # Load weights from every checkpoint and move to the specified device
                num = 0
                for idx in range(ck_number):
                    for n in range(cond_lora_number[idx]):
                        lora_attn_procs[name].q_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.down.weight', None)
                        lora_attn_procs[name].q_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.up.weight', None)
                        lora_attn_procs[name].k_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.down.weight', None)
                        lora_attn_procs[name].k_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.up.weight', None)
                        lora_attn_procs[name].v_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.down.weight', None)
                        lora_attn_procs[name].v_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.up.weight', None)
                        lora_attn_procs[name].proj_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.proj_loras.{n}.down.weight', None)
                        lora_attn_procs[name].proj_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.proj_loras.{n}.up.weight', None)
                        lora_attn_procs[name].to(device)  # move to the caller-supplied device
                        num += 1

            elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
                # Gather each checkpoint's entries for this single-stream block.
                lora_state_dicts = [{} for _ in range(ck_number)]
                for idx, checkpoint in enumerate(checkpoints):
                    for key, value in checkpoint.items():
                        if re.search(r'\.(\d+)\.', key):
                            checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
                            if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
                                lora_state_dicts[idx][key] = value

                lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
                    dim=3072,
                    ranks=ranks,
                    network_alphas=ranks,
                    lora_weights=multi_lora_weight,
                    device=device,  # use the device passed in by the caller
                    dtype=torch.bfloat16,
                    cond_width=cond_size,
                    cond_height=cond_size,
                    n_loras=cond_number
                )

                # Load weights from every checkpoint and move to the specified device
                num = 0
                for idx in range(ck_number):
                    for n in range(cond_lora_number[idx]):
                        lora_attn_procs[name].q_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.down.weight', None)
                        lora_attn_procs[name].q_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.up.weight', None)
                        lora_attn_procs[name].k_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.down.weight', None)
                        lora_attn_procs[name].k_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.up.weight', None)
                        lora_attn_procs[name].v_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.down.weight', None)
                        lora_attn_procs[name].v_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.up.weight', None)
                        lora_attn_procs[name].to(device)  # move to the caller-supplied device
                        num += 1

            else:
                lora_attn_procs[name] = FluxAttnProcessor2_0()

    transformer.set_attn_processor(lora_attn_procs)


def set_single_lora(transformer, local_path, lora_weights=[], cond_size=512, device="cpu"):
    """Load one LoRA checkpoint and attach it to `transformer`."""
    checkpoint = load_checkpoint(local_path)
    update_model_with_lora(checkpoint, lora_weights, transformer, cond_size, device=device)


def set_multi_lora(transformer, local_paths, lora_weights=[[]], cond_size=512, device="cpu"):
    """Load several LoRA checkpoints and attach them all to `transformer`.
    Also updated to take the device as a parameter."""
    checkpoints = [load_checkpoint(local_path) for local_path in local_paths]
    update_model_with_multi_lora(checkpoints, lora_weights, transformer, cond_size, device=device)


def unset_lora(transformer):
    """Restore the default FLUX attention processors, removing all LoRA branches."""
    lora_attn_procs = {}
    for name, attn_processor in transformer.attn_processors.items():
        lora_attn_procs[name] = FluxAttnProcessor2_0()
    transformer.set_attn_processor(lora_attn_procs)
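

# Example usage (illustrative sketch only; the checkpoint paths, the LoRA weight
# values, and the diffusers loading call below are assumptions for illustration,
# not part of this module):
#
#     from diffusers import FluxTransformer2DModel
#
#     transformer = FluxTransformer2DModel.from_pretrained(
#         "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
#     )
#     # Attach a single condition LoRA, building the processors on the GPU.
#     set_single_lora(transformer, "path/to/lora.safetensors",
#                     lora_weights=[1.0], cond_size=512, device="cuda")
#     # Or combine several LoRA checkpoints, one weight list per checkpoint.
#     set_multi_lora(transformer, ["path/to/lora_a.safetensors", "path/to/lora_b.safetensors"],
#                    lora_weights=[[1.0], [1.0]], cond_size=512, device="cuda")
#     # Restore the stock attention processors when done.
#     unset_lora(transformer)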