import logging
from typing import Optional

import torch

import comfy.model_management
from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose, factorization
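
# OFT (Orthogonal Finetuning) adapts a weight by multiplying it with a
# block-diagonal orthogonal matrix rather than adding a low-rank delta.
# Each block R is built from a learned skew-symmetric matrix Q via the
# Cayley transform R = (I + Q)(I - Q)^-1, which is orthogonal by construction.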


class OFTDiff(WeightAdapterTrainBase):
    def __init__(self, weights):
        super().__init__()
        # Unpack weights tuple from OFTAdapter
        blocks, rescale, alpha, _ = weights

        # Create trainable parameters
        self.oft_blocks = torch.nn.Parameter(blocks)
        if rescale is not None:
            self.rescale = torch.nn.Parameter(rescale)
            self.rescaled = True
        else:
            self.rescaled = False
        self.block_num, self.block_size, _ = blocks.shape
        self.constraint = float(alpha)
        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)

    def __call__(self, w):
        org_dtype = w.dtype
        I = torch.eye(self.block_size, device=self.oft_blocks.device)

        ## generate r
        # build a skew-symmetric Q (Q = -Q^T) from the trainable blocks
        q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
        normed_q = q
        if self.constraint:
            q_norm = torch.norm(q) + 1e-8
            if q_norm > self.constraint:
                normed_q = q * self.constraint / q_norm
        # use float() to prevent unsupported type in .inverse()
        r = (I + normed_q) @ (I - normed_q).float().inverse()
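        # r is the Cayley transform of normed_q: for skew-symmetric Q,
        # (I + Q)(I - Q)^-1 is orthogonal, so each block is rotated rather
        # than rescaled; the clamp above enforces the constraint ||Q|| <= alpha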

        ## Apply chunked matmul on weight
        _, *shape = w.shape
        org_weight = w.to(dtype=r.dtype)
        org_weight = org_weight.unflatten(0, (self.block_num, self.block_size))
        # blocks are zero-initialized, so Q = 0 and R = I at step 0, which
        # keeps the initial output identical to the original model output
        weight = torch.einsum(
            "k n m, k n ... -> k m ...",
            r,
            org_weight,
        ).flatten(0, 1)
        if self.rescaled:
            weight = self.rescale * weight
        return weight.to(org_dtype)

    def passive_memory_usage(self):
        """Calculates memory usage of the trainable parameters."""
        return sum(param.numel() * param.element_size() for param in self.parameters())
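

# A minimal sanity-check sketch (hypothetical helper, not part of the adapter
# API): freshly created blocks are zero, so Q = 0, the Cayley transform gives
# R = I, and OFTDiff must return the weight unchanged.
def _demo_oft_identity():
    w = torch.randn(8, 4)          # e.g. a Linear weight, out_dim=8
    blocks = torch.zeros(2, 4, 4)  # block_num=2, block_size=4, 2 * 4 == out_dim
    diff = OFTDiff((blocks, None, 1.0, None))
    assert torch.allclose(diff(w), w, atol=1e-6)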


class OFTAdapter(WeightAdapterBase):
    name = "oft"

    def __init__(self, loaded_keys, weights):
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def create_train(cls, weight, rank=1, alpha=1.0):
        out_dim = weight.shape[0]
        block_size, block_num = factorization(out_dim, rank)
        block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=weight.dtype)
        return OFTDiff(
            (block, None, alpha, None)
        )

    def to_train(self):
        return OFTDiff(self.weights)

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: Optional[set[str]] = None,
    ) -> Optional["OFTAdapter"]:
        if loaded_keys is None:
            loaded_keys = set()
        blocks_name = "{}.oft_blocks".format(x)
        rescale_name = "{}.rescale".format(x)

        blocks = None
        if blocks_name in lora.keys():
            blocks = lora[blocks_name]
            if blocks.ndim == 3:
                loaded_keys.add(blocks_name)
            else:
                blocks = None
        if blocks is None:
            return None

        rescale = None
        if rescale_name in lora.keys():
            rescale = lora[rescale_name]
            loaded_keys.add(rescale_name)

        weights = (blocks, rescale, alpha, dora_scale)
        return cls(loaded_keys, weights)

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        v = self.weights
        blocks = v[0]
        rescale = v[1]
        alpha = v[2]
        if alpha is None:
            alpha = 0
        dora_scale = v[3]

        blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
        if rescale is not None:
            rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)

        block_num, block_size, *_ = blocks.shape

        try:
            # Get r
            I = torch.eye(block_size, device=blocks.device, dtype=blocks.dtype)
            # build a skew-symmetric Q (Q = -Q^T)
            q = blocks - blocks.transpose(1, 2)
            normed_q = q
            if alpha > 0:  # alpha in oft/boft is the norm constraint
                q_norm = torch.norm(q) + 1e-8
                if q_norm > alpha:
                    normed_q = q * alpha / q_norm
            # use float() to prevent unsupported type in .inverse()
            r = (I + normed_q) @ (I - normed_q).float().inverse()
            r = r.to(weight)
            _, *shape = weight.shape
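            # (R - I) applied blockwise gives the OFT delta; strength scales
            # it exactly once below (or inside weight_decompose for DoRA)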
            lora_diff = torch.einsum(
                "k n m, k n ... -> k m ...",
                r - I,
                weight.view(block_num, block_size, *shape),
            ).view(-1, *shape)
            if dora_scale is not None:
                weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
            else:
                weight += function((strength * lora_diff).type(weight.dtype))
        except Exception as e:
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight
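

# Usage sketch (hypothetical key prefix, for illustration only). A minimal
# sketch, assuming an OFT state dict stores 3-D `oft_blocks` under
# "<prefix>.oft_blocks"; `load` returns None when that key is missing. With
# zero blocks, R == I and the delta is zero, so calculate_weight should hand
# back the weight unchanged. Run via `python -m` so the relative import works.
if __name__ == "__main__":
    prefix = "lora_unet_example_layer"  # hypothetical prefix
    sd = {"{}.oft_blocks".format(prefix): torch.zeros(2, 4, 4)}
    adapter = OFTAdapter.load(prefix, sd, alpha=1.0, dora_scale=None)
    assert adapter is not None

    w = torch.randn(8, 4)
    out = adapter.calculate_weight(w.clone(), prefix, 1.0, 1.0, None, lambda t: t)
    assert torch.allclose(out, w, atol=1e-5)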