from typing import *

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.linalg as linalg

from tqdm import tqdm

def make_sparse(t: torch.Tensor, sparsity=0.95):
    # Keep only the largest-magnitude entries of t: everything below the
    # `sparsity` quantile of |t| is zeroed out.
    abs_t = torch.abs(t)
    np_array = abs_t.detach().cpu().numpy()
    quan = float(np.quantile(np_array, sparsity))
    sparse_t = t.masked_fill(abs_t < quan, 0)
    return sparse_t
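
# Minimal sanity-check sketch (illustration only, not part of the original
# script): with sparsity=0.95 roughly 5% of the entries should survive.
def _demo_make_sparse():
    t = torch.randn(64, 64)
    s = make_sparse(t, sparsity=0.95)
    kept = (s != 0).float().mean().item()
    print(f'non-zero fraction after make_sparse: {kept:.3f}')  # ~0.05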

def extract_conv(
    weight: Union[torch.Tensor, nn.Parameter],
    mode = 'fixed',
    mode_param = 0,
    device = 'cpu',
    is_cp = False,
) -> Tuple[nn.Parameter, nn.Parameter]:
    weight = weight.to(device)
    out_ch, in_ch, kernel_size, _ = weight.shape

    # SVD of the flattened kernel: (out_ch) x (in_ch * k * k)
    U, S, Vh = linalg.svd(weight.reshape(out_ch, -1))

    # Pick the rank according to the selected mode.
    if mode == 'fixed':
        lora_rank = mode_param
    elif mode == 'threshold':
        assert mode_param >= 0
        lora_rank = torch.sum(S > mode_param)
    elif mode == 'ratio':
        assert 1 >= mode_param >= 0
        min_s = torch.max(S) * mode_param
        lora_rank = torch.sum(S > min_s)
    elif mode == 'quantile' or mode == 'percentile':
        assert 1 >= mode_param >= 0
        s_cum = torch.cumsum(S, dim=0)
        min_cum_sum = mode_param * torch.sum(S)
        lora_rank = torch.sum(s_cum < min_cum_sum)
    else:
        raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
    lora_rank = max(1, lora_rank)
    lora_rank = min(out_ch, in_ch, lora_rank)

    # If the selected rank is not actually "low", keep the full difference instead.
    if lora_rank >= out_ch / 2 and not is_cp:
        return weight, 'full'

    U = U[:, :lora_rank]
    S = S[:lora_rank]
    U = U @ torch.diag(S)
    Vh = Vh[:lora_rank, :]

    diff = (weight - (U @ Vh).reshape(out_ch, in_ch, kernel_size, kernel_size)).detach()
    extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach()
    extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach()
    del U, S, Vh, weight
    return (extract_weight_A, extract_weight_B, diff), 'low rank'
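
# Minimal usage sketch (illustration only): extract a rank-8 factorization of a
# random 3x3 conv weight and check that down/up multiply back to the truncated SVD.
def _demo_extract_conv():
    w = torch.randn(128, 64, 3, 3)
    result, kind = extract_conv(w, mode='fixed', mode_param=8)
    assert kind == 'low rank'
    down, up, residual = result
    rebuilt = (up.flatten(1) @ down.flatten(1)).reshape(w.shape)
    print('residual matches:', torch.allclose(w - rebuilt, residual, atol=1e-4))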

def extract_linear(
    weight: Union[torch.Tensor, nn.Parameter],
    mode = 'fixed',
    mode_param = 0,
    device = 'cpu',
) -> Tuple[nn.Parameter, nn.Parameter]:
    weight = weight.to(device)
    out_ch, in_ch = weight.shape

    U, S, Vh = linalg.svd(weight)

    if mode == 'fixed':
        lora_rank = mode_param
    elif mode == 'threshold':
        assert mode_param >= 0
        lora_rank = torch.sum(S > mode_param)
    elif mode == 'ratio':
        assert 1 >= mode_param >= 0
        min_s = torch.max(S) * mode_param
        lora_rank = torch.sum(S > min_s)
    elif mode == 'quantile' or mode == 'percentile':
        assert 1 >= mode_param >= 0
        s_cum = torch.cumsum(S, dim=0)
        min_cum_sum = mode_param * torch.sum(S)
        lora_rank = torch.sum(s_cum < min_cum_sum)
    else:
        raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
    lora_rank = max(1, lora_rank)
    lora_rank = min(out_ch, in_ch, lora_rank)

    if lora_rank >= out_ch / 2:
        return weight, 'full'

    U = U[:, :lora_rank]
    S = S[:lora_rank]
    U = U @ torch.diag(S)
    Vh = Vh[:lora_rank, :]

    diff = (weight - U @ Vh).detach()
    extract_weight_A = Vh.reshape(lora_rank, in_ch).detach()
    extract_weight_B = U.reshape(out_ch, lora_rank).detach()
    del U, S, Vh, weight
    return (extract_weight_A, extract_weight_B, diff), 'low rank'
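
# Minimal usage sketch (illustration only): with mode='quantile' the rank is the
# number of singular values needed to cover mode_param of the total spectral mass.
def _demo_extract_linear():
    w = torch.randn(320, 320)
    result, kind = extract_linear(w, mode='quantile', mode_param=0.5)
    if kind == 'low rank':
        down, up, residual = result
        print('selected rank:', down.shape[0])
    else:
        # A random matrix has a nearly flat spectrum, so the rank cap (out_ch / 2)
        # may be hit and the full difference is returned instead.
        print('kept full weight, shape:', tuple(result.shape))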

def extract_diff(
    base_model,
    db_model,
    mode = 'fixed',
    linear_mode_param = 0,
    conv_mode_param = 0,
    extract_device = 'cpu',
    use_bias = False,
    sparsity = 0.98,
    small_conv = True
):
    UNET_TARGET_REPLACE_MODULE = [
        "Transformer2DModel",
        "Attention",
        "ResnetBlock2D",
        "Downsample2D",
        "Upsample2D"
    ]
    UNET_TARGET_REPLACE_NAME = [
        "conv_in",
        "conv_out",
        "time_embedding.linear_1",
        "time_embedding.linear_2",
    ]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = 'lora_unet'
    LORA_PREFIX_TEXT_ENCODER = 'lora_te'

    def make_state_dict(
        prefix,
        root_module: torch.nn.Module,
        target_module: torch.nn.Module,
        target_replace_modules,
        target_replace_names = []
    ):
        loras = {}
        temp = {}
        temp_name = {}

        # First pass: cache the base-model weights of every Linear/Conv2d that
        # lives under a targeted module class or matches a targeted name.
        for name, module in root_module.named_modules():
            if module.__class__.__name__ in target_replace_modules:
                temp[name] = {}
                for child_name, child_module in module.named_modules():
                    if child_module.__class__.__name__ not in {'Linear', 'Conv2d'}:
                        continue
                    temp[name][child_name] = child_module.weight
            elif name in target_replace_names:
                temp_name[name] = module.weight

        # Second pass: walk the tuned model, compute weight differences against the
        # cached base weights and decompose them into LoRA/LoCon factors.
        for name, module in tqdm(list(target_module.named_modules())):
            if name in temp:
                weights = temp[name]
                for child_name, child_module in module.named_modules():
                    lora_name = prefix + '.' + name + '.' + child_name
                    lora_name = lora_name.replace('.', '_')
                    layer = child_module.__class__.__name__

                    if layer in {'Linear', 'Conv2d'}:
                        root_weight = child_module.weight
                        if torch.allclose(root_weight, weights[child_name]):
                            continue

                    if layer == 'Linear':
                        weight, decompose_mode = extract_linear(
                            (child_module.weight - weights[child_name]),
                            mode,
                            linear_mode_param,
                            device = extract_device,
                        )
                        if decompose_mode == 'low rank':
                            extract_a, extract_b, diff = weight
                    elif layer == 'Conv2d':
                        is_linear = (child_module.weight.shape[2] == 1
                                     and child_module.weight.shape[3] == 1)
                        weight, decompose_mode = extract_conv(
                            (child_module.weight - weights[child_name]),
                            mode,
                            linear_mode_param if is_linear else conv_mode_param,
                            device = extract_device,
                        )
                        if decompose_mode == 'low rank':
                            extract_a, extract_b, diff = weight
                        if small_conv and not is_linear and decompose_mode == 'low rank':
                            # Optional CP-style split: factor the k x k "down" kernel again
                            # into a 1x1 down kernel and a k x k mid kernel.
                            dim = extract_a.size(0)
                            (extract_c, extract_a, _), _ = extract_conv(
                                extract_a.transpose(0, 1),
                                'fixed', dim,
                                extract_device, True
                            )
                            extract_a = extract_a.transpose(0, 1)
                            extract_c = extract_c.transpose(0, 1)
                            loras[f'{lora_name}.lora_mid.weight'] = extract_c.detach().cpu().contiguous().half()
                            diff = child_module.weight - torch.einsum(
                                'i j k l, j r, p i -> p r k l',
                                extract_c, extract_a.flatten(1, -1), extract_b.flatten(1, -1)
                            ).detach().cpu().contiguous()
                            del extract_c
                    else:
                        continue

                    if decompose_mode == 'low rank':
                        loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().contiguous().half()
                        loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().contiguous().half()
                        loras[f'{lora_name}.alpha'] = torch.Tensor([extract_a.shape[0]]).half()
                        if use_bias:
                            # Store the remaining residual as a sparse correction.
                            diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
                            sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()

                            indices = sparse_diff.indices().to(torch.int16)
                            values = sparse_diff.values().half()
                            loras[f'{lora_name}.bias_indices'] = indices
                            loras[f'{lora_name}.bias_values'] = values
                            loras[f'{lora_name}.bias_size'] = torch.tensor(diff.shape).to(torch.int16)
                        del extract_a, extract_b, diff
                    elif decompose_mode == 'full':
                        loras[f'{lora_name}.diff'] = weight.detach().cpu().contiguous().half()
                    else:
                        raise NotImplementedError
            elif name in temp_name:
                weights = temp_name[name]
                lora_name = prefix + '.' + name
                lora_name = lora_name.replace('.', '_')
                layer = module.__class__.__name__

                if layer in {'Linear', 'Conv2d'}:
                    root_weight = module.weight
                    if torch.allclose(root_weight, weights):
                        continue

                if layer == 'Linear':
                    weight, decompose_mode = extract_linear(
                        (root_weight - weights),
                        mode,
                        linear_mode_param,
                        device = extract_device,
                    )
                    if decompose_mode == 'low rank':
                        extract_a, extract_b, diff = weight
                elif layer == 'Conv2d':
                    is_linear = (
                        root_weight.shape[2] == 1
                        and root_weight.shape[3] == 1
                    )
                    weight, decompose_mode = extract_conv(
                        (root_weight - weights),
                        mode,
                        linear_mode_param if is_linear else conv_mode_param,
                        device = extract_device,
                    )
                    if decompose_mode == 'low rank':
                        extract_a, extract_b, diff = weight
                    if small_conv and not is_linear and decompose_mode == 'low rank':
                        dim = extract_a.size(0)
                        (extract_c, extract_a, _), _ = extract_conv(
                            extract_a.transpose(0, 1),
                            'fixed', dim,
                            extract_device, True
                        )
                        extract_a = extract_a.transpose(0, 1)
                        extract_c = extract_c.transpose(0, 1)
                        loras[f'{lora_name}.lora_mid.weight'] = extract_c.detach().cpu().contiguous().half()
                        diff = root_weight - torch.einsum(
                            'i j k l, j r, p i -> p r k l',
                            extract_c, extract_a.flatten(1, -1), extract_b.flatten(1, -1)
                        ).detach().cpu().contiguous()
                        del extract_c
                else:
                    continue

                if decompose_mode == 'low rank':
                    loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().contiguous().half()
                    loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().contiguous().half()
                    loras[f'{lora_name}.alpha'] = torch.Tensor([extract_a.shape[0]]).half()
                    if use_bias:
                        diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
                        sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()

                        indices = sparse_diff.indices().to(torch.int16)
                        values = sparse_diff.values().half()
                        loras[f'{lora_name}.bias_indices'] = indices
                        loras[f'{lora_name}.bias_values'] = values
                        loras[f'{lora_name}.bias_size'] = torch.tensor(diff.shape).to(torch.int16)
                    del extract_a, extract_b, diff
                elif decompose_mode == 'full':
                    loras[f'{lora_name}.diff'] = weight.detach().cpu().contiguous().half()
                else:
                    raise NotImplementedError
        return loras

    text_encoder_loras = make_state_dict(
        LORA_PREFIX_TEXT_ENCODER,
        base_model[0], db_model[0],
        TEXT_ENCODER_TARGET_REPLACE_MODULE
    )

    unet_loras = make_state_dict(
        LORA_PREFIX_UNET,
        base_model[2], db_model[2],
        UNET_TARGET_REPLACE_MODULE,
        UNET_TARGET_REPLACE_NAME
    )
    print(len(text_encoder_loras), len(unet_loras))
    return text_encoder_loras | unet_loras
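
# Usage sketch (illustration only; the model-loading details are assumptions,
# not part of this script): extract_diff expects `base_model` and `db_model` as
# sequences whose index 0 is the text encoder and index 2 is the UNet, e.g.
# [text_encoder, vae, unet] taken from two diffusers pipelines.
def _demo_extract_diff(base_pipe, tuned_pipe):
    base = [base_pipe.text_encoder, base_pipe.vae, base_pipe.unet]
    tuned = [tuned_pipe.text_encoder, tuned_pipe.vae, tuned_pipe.unet]
    state_dict = extract_diff(
        base, tuned,
        mode='fixed',
        linear_mode_param=64,   # rank used for Linear / 1x1 Conv layers
        conv_mode_param=32,     # rank used for k x k Conv layers
        extract_device='cpu',
    )
    return state_dict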

def get_module(
    lyco_state_dict: Dict,
    lora_name
):
    # Detect which LyCORIS algorithm the keys under `lora_name` belong to and
    # return its parameters.
    if f'{lora_name}.lora_up.weight' in lyco_state_dict:
        up = lyco_state_dict[f'{lora_name}.lora_up.weight']
        down = lyco_state_dict[f'{lora_name}.lora_down.weight']
        mid = lyco_state_dict.get(f'{lora_name}.lora_mid.weight', None)
        alpha = lyco_state_dict.get(f'{lora_name}.alpha', None)
        return 'locon', (up, down, mid, alpha)
    elif f'{lora_name}.hada_w1_a' in lyco_state_dict:
        w1a = lyco_state_dict[f'{lora_name}.hada_w1_a']
        w1b = lyco_state_dict[f'{lora_name}.hada_w1_b']
        w2a = lyco_state_dict[f'{lora_name}.hada_w2_a']
        w2b = lyco_state_dict[f'{lora_name}.hada_w2_b']
        t1 = lyco_state_dict.get(f'{lora_name}.hada_t1', None)
        t2 = lyco_state_dict.get(f'{lora_name}.hada_t2', None)
        alpha = lyco_state_dict.get(f'{lora_name}.alpha', None)
        return 'hada', (w1a, w1b, w2a, w2b, t1, t2, alpha)
    elif f'{lora_name}.weight' in lyco_state_dict:
        weight = lyco_state_dict[f'{lora_name}.weight']
        on_input = lyco_state_dict.get(f'{lora_name}.on_input', False)
        return 'ia3', (weight, on_input)
    elif (f'{lora_name}.lokr_w1' in lyco_state_dict
          or f'{lora_name}.lokr_w1_a' in lyco_state_dict):
        w1 = lyco_state_dict.get(f'{lora_name}.lokr_w1', None)
        w1a = lyco_state_dict.get(f'{lora_name}.lokr_w1_a', None)
        w1b = lyco_state_dict.get(f'{lora_name}.lokr_w1_b', None)
        w2 = lyco_state_dict.get(f'{lora_name}.lokr_w2', None)
        w2a = lyco_state_dict.get(f'{lora_name}.lokr_w2_a', None)
        w2b = lyco_state_dict.get(f'{lora_name}.lokr_w2_b', None)
        t1 = lyco_state_dict.get(f'{lora_name}.lokr_t1', None)
        t2 = lyco_state_dict.get(f'{lora_name}.lokr_t2', None)
        alpha = lyco_state_dict.get(f'{lora_name}.alpha', None)
        return 'kron', (w1, w1a, w1b, w2, w2a, w2b, t1, t2, alpha)
    elif f'{lora_name}.diff' in lyco_state_dict:
        return 'full', lyco_state_dict[f'{lora_name}.diff']
    else:
        return 'None', ()

def cp_weight_from_conv(
    up, down, mid
):
    # Rebuild a full (out, in, k, k) conv kernel from the CP-style factors
    # produced by the small_conv path: 1x1 up, 1x1 down and k x k mid.
    up = up.reshape(up.size(0), up.size(1))
    down = down.reshape(down.size(0), down.size(1))
    return torch.einsum('m n w h, i m, n j -> i j w h', mid, up, down)


def cp_weight(
    wa, wb, t
):
    temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
    return torch.einsum('i j k l, i r -> r j k l', temp, wa)
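
# Shape sketch (illustration only): with rank r, `up` is (out, r, 1, 1),
# `down` is (r, in, 1, 1) and `mid` is (r, r, k, k); the rebuilt kernel is
# (out, in, k, k).
def _demo_cp_weight_from_conv():
    out_ch, in_ch, r, k = 32, 16, 4, 3
    up = torch.randn(out_ch, r, 1, 1)
    down = torch.randn(r, in_ch, 1, 1)
    mid = torch.randn(r, r, k, k)
    w = cp_weight_from_conv(up, down, mid)
    print(w.shape)  # torch.Size([32, 16, 3, 3])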

def rebuild_weight(module_type, params, orig_weight, scale=1):
    if orig_weight is None:
        return orig_weight
    merged = orig_weight
    if module_type == 'locon':
        up, down, mid, alpha = params
        if alpha is not None:
            scale *= alpha / up.size(1)
        if mid is not None:
            rebuild = cp_weight_from_conv(up, down, mid)
        else:
            rebuild = up.reshape(up.size(0), -1) @ down.reshape(down.size(0), -1)
        merged = orig_weight + rebuild.reshape(orig_weight.shape) * scale
        del up, down, mid, alpha, params, rebuild
    elif module_type == 'hada':
        w1a, w1b, w2a, w2b, t1, t2, alpha = params
        if alpha is not None:
            scale *= alpha / w1b.size(0)
        if t1 is not None:
            rebuild1 = cp_weight(w1a, w1b, t1)
        else:
            rebuild1 = w1a @ w1b
        if t2 is not None:
            rebuild2 = cp_weight(w2a, w2b, t2)
        else:
            rebuild2 = w2a @ w2b
        rebuild = (rebuild1 * rebuild2).reshape(orig_weight.shape)
        merged = orig_weight + rebuild * scale
        del w1a, w1b, w2a, w2b, t1, t2, alpha, params, rebuild, rebuild1, rebuild2
    elif module_type == 'ia3':
        weight, on_input = params
        if not on_input:
            weight = weight.reshape(-1, 1)
        merged = orig_weight + weight * orig_weight * scale
        del weight, on_input, params
    elif module_type == 'kron':
        w1, w1a, w1b, w2, w2a, w2b, t1, t2, alpha = params
        # Use `is not None` rather than truthiness: `if tensor` raises for
        # multi-element tensors.
        if alpha is not None and (w1b is not None or w2b is not None):
            scale *= alpha / (w1b.size(0) if w1b is not None else w2b.size(0))
        if w1a is not None and w1b is not None:
            if t1 is not None:
                w1 = cp_weight(w1a, w1b, t1)
            else:
                w1 = w1a @ w1b
        if w2a is not None and w2b is not None:
            if t2 is not None:
                w2 = cp_weight(w2a, w2b, t2)
            else:
                w2 = w2a @ w2b
        rebuild = torch.kron(w1, w2).reshape(orig_weight.shape)
        merged = orig_weight + rebuild * scale
        del w1, w1a, w1b, w2, w2a, w2b, t1, t2, alpha, params, rebuild
    elif module_type == 'full':
        rebuild = params.reshape(orig_weight.shape)
        merged = orig_weight + rebuild * scale
        del params, rebuild
    return merged
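
# Minimal sanity-check sketch (illustration only): extract a linear difference,
# rebuild it with rebuild_weight and compare against the tuned weight.
def _demo_rebuild_locon():
    base = torch.randn(64, 64)
    tuned = base + torch.randn(64, 8) @ torch.randn(8, 64) * 0.01  # rank-8 edit
    result, kind = extract_linear(tuned - base, mode='fixed', mode_param=8)
    assert kind == 'low rank'
    down, up, _ = result
    alpha = torch.tensor([down.shape[0]], dtype=torch.float32)  # alpha == rank -> scale 1
    merged = rebuild_weight('locon', (up, down, None, alpha), base, scale=1.0)
    print('max abs error:', (merged - tuned).abs().max().item())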

def merge(
    base_model,
    lyco_state_dict,
    scale: float = 1.0,
    device = 'cpu'
):
    UNET_TARGET_REPLACE_MODULE = [
        "Transformer2DModel",
        "Attention",
        "ResnetBlock2D",
        "Downsample2D",
        "Upsample2D"
    ]
    UNET_TARGET_REPLACE_NAME = [
        "conv_in",
        "conv_out",
        "time_embedding.linear_1",
        "time_embedding.linear_2",
    ]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = 'lora_unet'
    LORA_PREFIX_TEXT_ENCODER = 'lora_te'
    merged = 0

    def merge_state_dict(
        prefix,
        root_module: torch.nn.Module,
        lyco_state_dict: Dict[str, torch.Tensor],
        target_replace_modules,
        target_replace_names = []
    ):
        nonlocal merged
        for name, module in tqdm(list(root_module.named_modules()), desc=f'Merging {prefix}'):
            if module.__class__.__name__ in target_replace_modules:
                for child_name, child_module in module.named_modules():
                    if child_module.__class__.__name__ not in {'Linear', 'Conv2d'}:
                        continue
                    lora_name = prefix + '.' + name + '.' + child_name
                    lora_name = lora_name.replace('.', '_')

                    # Rebuild the merged weight for this layer (if the state dict
                    # has matching keys) and write it back in place.
                    result = rebuild_weight(*get_module(
                        lyco_state_dict, lora_name
                    ), getattr(child_module, 'weight'), scale)
                    if result is not None:
                        merged += 1
                        child_module.requires_grad_(False)
                        child_module.weight.copy_(result)
            elif name in target_replace_names:
                lora_name = prefix + '.' + name
                lora_name = lora_name.replace('.', '_')

                result = rebuild_weight(*get_module(
                    lyco_state_dict, lora_name
                ), getattr(module, 'weight'), scale)
                if result is not None:
                    merged += 1
                    module.requires_grad_(False)
                    module.weight.copy_(result)

    if device == 'cpu':
        # Promote everything to fp32 first; half-precision ops are poorly supported on CPU.
        for k, v in tqdm(list(lyco_state_dict.items()), desc='Converting Dtype'):
            lyco_state_dict[k] = v.float()

    merge_state_dict(
        LORA_PREFIX_TEXT_ENCODER,
        base_model[0],
        lyco_state_dict,
        TEXT_ENCODER_TARGET_REPLACE_MODULE,
        UNET_TARGET_REPLACE_NAME
    )
    merge_state_dict(
        LORA_PREFIX_UNET,
        base_model[2],
        lyco_state_dict,
        UNET_TARGET_REPLACE_MODULE,
        UNET_TARGET_REPLACE_NAME
    )
    print(f'{merged} modules merged')
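
# Usage sketch (illustration only; the loading helpers below are assumptions,
# not part of this script): load a LyCORIS/LoCon state dict from safetensors
# and merge it into a diffusers pipeline in place.
def _demo_merge(pipe, lyco_path, scale=0.8):
    from safetensors.torch import load_file
    lyco_state_dict = load_file(lyco_path)
    models = [pipe.text_encoder, pipe.vae, pipe.unet]  # indices 0 and 2 are used
    with torch.no_grad():
        merge(models, lyco_state_dict, scale=scale, device='cpu')
    return pipe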