import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0
from diffusers.utils.import_utils import is_xformers_available

from merge import TokenMergeAttentionProcessor

# Capability flags used by remove_patch() to decide which stock processor to
# restore: prefer torch 2.x SDPA, then xformers, then the base AttnProcessor.
xformers_is_available = is_xformers_available()
torch2_is_available = hasattr(F, "scaled_dot_product_attention")


def hook_tome_model(model: torch.nn.Module):
    """Adds a forward pre-hook that records the latent spatial size and the
    current timestep on every UNet call. The hook is removed by remove_patch."""

    def hook(module, args):
        # args[0] is the latent input of shape (B, C, H, W); args[1] is the timestep.
        module._tome_info["size"] = (args[0].shape[2], args[0].shape[3])
        module._tome_info["timestep"] = args[1].item()
        return None

    model._tome_info["hooks"].append(model.register_forward_pre_hook(hook))


def patch_attention_proc(unet, token_merge_args=None):
    token_merge_args = token_merge_args or {}  # avoid a mutable default argument
    unet._tome_info = {
        "size": None,
        "timestep": None,
        "hooks": [],
        "args": {
            "ratio": token_merge_args.get("ratio", 0.5),  # ratio of tokens to merge
            "sx": token_merge_args.get("sx", 2),  # stride x for sim calculation
            "sy": token_merge_args.get("sy", 2),  # stride y for sim calculation
            "use_rand": token_merge_args.get("use_rand", True),
            "generator": None,

            "merge_tokens": token_merge_args.get("merge_tokens", "keys/values"),  # ["all", "keys/values"]
            "merge_method": token_merge_args.get("merge_method", "downsample"),  # ["none","similarity", "downsample"]
            "downsample_method": token_merge_args.get("downsample_method", "nearest-exact"),
            # native torch interpolation methods ["nearest", "linear", "bilinear", "bicubic", "nearest-exact"]
            "downsample_factor": token_merge_args.get("downsample_factor", 2),  # amount to downsample by
            "timestep_threshold_switch": token_merge_args.get("timestep_threshold_switch", 0.2),
            # timestep to switch to secondary method, 0.2 means 20% steps remaining
            "timestep_threshold_stop": token_merge_args.get("timestep_threshold_stop", 0.0),
            # timestep to stop merging, 0.0 means stop at 0 steps remaining
            "secondary_merge_method": token_merge_args.get("secondary_merge_method", "similarity"),
            # ["none", "similarity", "downsample"]

            "downsample_factor_level_2": token_merge_args.get("downsample_factor_level_2", 1), # amount to downsample by at the 2nd down block of unet
            "ratio_level_2": token_merge_args.get("ratio_level_2", 0.5), # ratio of tokens to merge at the 2nd down block of unet
        }
    }
    hook_tome_model(unet)
    # Swap the self-attention (attn1) processor of every transformer block for
    # the token-merging processor; all processors share the UNet-level state.
    attn_modules = [
        module for _, module in unet.named_modules()
        if module.__class__.__name__ == "BasicTransformerBlock"
    ]

    for module in attn_modules:
        module.attn1.processor = TokenMergeAttentionProcessor()
        module.attn1.processor._tome_info = unet._tome_info


def remove_patch(pipe):
    """Removes the ToMe patch from a pipeline if it was already patched."""

    # Detach the forward pre-hook registered by hook_tome_model.
    if hasattr(pipe.unet, "_tome_info"):
        for hook in pipe.unet._tome_info["hooks"]:
            hook.remove()
        pipe.unet._tome_info["hooks"].clear()

    # Restore a stock attention processor in place of our custom class,
    # preferring torch 2.x SDPA, then xformers, then the base processor.
    if torch2_is_available:
        for _, m in pipe.unet.named_modules():
            if hasattr(m, "processor"):
                m.processor = AttnProcessor2_0()

    elif xformers_is_available:
        pipe.enable_xformers_memory_efficient_attention()

    else:
        for _, m in pipe.unet.named_modules():
            if hasattr(m, "processor"):
                m.processor = AttnProcessor()
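

# ---------------------------------------------------------------------------
# Example usage: a minimal sketch of patching, sampling, and unpatching.
# Assumes a Stable Diffusion pipeline from diffusers; the checkpoint id,
# prompt, and output filename below are illustrative, not part of this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    # Merge half of the key/value tokens by downsampling, then switch to
    # similarity-based merging once 20% of the denoising steps remain.
    patch_attention_proc(pipe.unet, token_merge_args={
        "ratio": 0.5,
        "merge_method": "downsample",
        "secondary_merge_method": "similarity",
        "timestep_threshold_switch": 0.2,
    })

    image = pipe("a photo of an astronaut riding a horse on mars").images[0]
    image.save("tome_sample.png")

    # Restore the stock attention processors and detach the pre-hook.
    remove_patch(pipe)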