import math
import torch
import torchvision.transforms.functional as F
TOKENS = 75
    
def hook_forwards(self, root_module: torch.nn.Module):
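    # Replace the forward of every Attention module inside Flux's double blocks
    # ("transformer_blocks") and single blocks ("single_transformer_blocks") with a
    # region-aware wrapper that re-runs the stock processor once per regional prompt
    # and blends the results according to self.split_ratio and self.SR_delta.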
    for name, module in root_module.named_modules():
        if "attn" in name and "transformer_blocks" in name  and "single_transformer_blocks" not in name and module.__class__.__name__ == "Attention":
            module.forward = FluxTransformerBlock_hook_forward(self, module)           
        elif "attn" in name and "single_transformer_blocks" in name and module.__class__.__name__ == "Attention":
            module.forward = FluxSingleTransformerBlock_hook_forward(self, module) 

def FluxSingleTransformerBlock_hook_forward(self, module):
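    # In Flux single transformer blocks, text and image tokens travel as one
    # concatenated sequence; this code assumes the text prefix is 512 tokens long
    # (the pipeline's default T5 sequence length), hence the slices starting at 512.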
    def forward(hidden_states=None, encoder_hidden_states=None, image_rotary_emb=None, SR_encoder_hidden_states_list=None, SR_norm_encoder_hidden_states_list=None, SR_hidden_states_list=None, SR_norm_hidden_states_list=None):
        # Stock attention output for the incoming hidden states.
        flux_hidden_states = module.processor(module, hidden_states=hidden_states, image_rotary_emb=image_rotary_emb)

        height = self.h
        width = self.w
        # The latent grid is derived from the image-token count (everything after
        # the 512 text tokens) inside single_matsepcalc via split_dims.
        contexts_list = SR_norm_hidden_states_list

        def single_matsepcalc(x, contexts_list, image_rotary_emb):
            """Run attention once per regional context and stitch the image-token
            outputs back into one latent grid laid out by self.split_ratio."""
            h_states = []
            x_t = x.size()[1] - 512
            latent_h, latent_w = split_dims(x_t, height, width, self)
            latent_out = latent_w
            latent_in = latent_h
            i = 0
            sumout = 0
            SR_all_out_list = []

            for drow in self.split_ratio:
                v_states = []
                sumin = 0
                for dcell in drow.cols:
                    # Attention over this cell's regional context.
                    context = contexts_list[i]
                    i = i + 1 + dcell.breaks
                    SR_all_out = module.processor(module, hidden_states=context, image_rotary_emb=image_rotary_emb)
                    # Keep only the image tokens and lay them out on the latent grid.
                    out = SR_all_out[:, 512:, ...]
                    out = out.reshape(out.size()[0], latent_h, latent_w, out.size()[2])
                    addout = 0
                    addin = 0
                    sumin = sumin + int(latent_in * dcell.end) - int(latent_in * dcell.start)

                    # Rounding correction so the concatenated cells tile the grid exactly.
                    if dcell.end >= 0.999:
                        addin = sumin - latent_in
                        sumout = sumout + int(latent_out * drow.end) - int(latent_out * drow.start)
                        if drow.end >= 0.999:
                            addout = sumout - latent_out
                    # Crop this cell's rows and columns from the full-frame output.
                    out = out[:, int(latent_h * drow.start) + addout:int(latent_h * drow.end),
                                int(latent_w * dcell.start) + addin:int(latent_w * dcell.end), :]

                    v_states.append(out)
                    SR_all_out_list.append(SR_all_out)

                # Cells in a row are concatenated along width, rows along height.
                output_x = torch.cat(v_states, dim=2)
                h_states.append(output_x)

            output_x = torch.cat(h_states, dim=1)
            output_x = output_x.reshape(x.size()[0], x.size()[1] - 512, x.size()[2])
            new_SR_all_out_list = []

            # Write the stitched regional output back into each per-region tensor and
            # blend it into the base attention output with weight SR_delta.
            for SR_all_out in SR_all_out_list:
                SR_all_out[:, 512:, ...] = output_x
                new_SR_all_out_list.append(SR_all_out)
            x[:, 512:, ...] = output_x * self.SR_delta + x[:, 512:, ...] * (1 - self.SR_delta)

            return x, new_SR_all_out_list
        
        return single_matsepcalc(flux_hidden_states, contexts_list, image_rotary_emb)
    
    return forward

def FluxTransformerBlock_hook_forward(self, module):
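    # Double-block hook: the stock processor returns separate image and text streams.
    # Each regional encoder context is attended against the full image stream, the
    # result is cropped to that region's cell, the cells are tiled back into one
    # latent grid, and the grid is blended with the base output via SR_delta.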
    def forward(hidden_states=None, encoder_hidden_states=None, image_rotary_emb=None, SR_encoder_hidden_states_list=None, SR_norm_encoder_hidden_states_list=None, SR_hidden_states_list=None, SR_norm_hidden_states_list=None):
        # Stock joint attention for the incoming streams; the processor returns the
        # image stream and the text stream separately.
        flux_hidden_states, flux_encoder_hidden_states = module.processor(module, hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, image_rotary_emb=image_rotary_emb)
        
        height = self.h
        width = self.w
        # The latent grid is derived from the image-token count inside matsepcalc
        # via split_dims.

        contexts_list = SR_norm_encoder_hidden_states_list

        def matsepcalc(x, contexts_list, image_rotary_emb):
            """Run joint attention once per regional encoder context and stitch the
            per-cell image outputs back into one latent grid (see single_matsepcalc)."""
            h_states = []
            x_t = x.size()[1]
            latent_h, latent_w = split_dims(x_t, height, width, self)
            latent_out = latent_w
            latent_in = latent_h
            i = 0
            sumout = 0
            SR_context_attn_output_list = []

            for drow in self.split_ratio:
                v_states = []
                sumin = 0
                for dcell in drow.cols:
                    context = contexts_list[i]
                    i = i + 1 + dcell.breaks
                    # Joint attention of the full image stream against this cell's regional
                    # context, followed by the same crop-and-tile logic as in single_matsepcalc.
                    out, SR_context_attn_output = module.processor(module, hidden_states=x, encoder_hidden_states=context, image_rotary_emb=image_rotary_emb)
                    out = out.reshape(out.size()[0], latent_h, latent_w, out.size()[2])
                    addout = 0
                    addin = 0
                    sumin = sumin + int(latent_in*dcell.end) - int(latent_in*dcell.start)

                    if dcell.end >= 0.999:
                        addin = sumin - latent_in
                        sumout = sumout + int(latent_out*drow.end) - int(latent_out*drow.start)
                        if drow.end >= 0.999:
                            addout = sumout - latent_out

                    out = out[:, int(latent_h*drow.start) + addout:int(latent_h*drow.end),
                                int(latent_w*dcell.start) + addin:int(latent_w*dcell.end), :]
                    v_states.append(out)
                    SR_context_attn_output_list.append(SR_context_attn_output)

                output_x = torch.cat(v_states, dim=2)
                h_states.append(output_x)

            output_x = torch.cat(h_states, dim=1)
            output_x = output_x.reshape(x.size()[0], x.size()[1], x.size()[2])

            # Blend the stitched regional output with the stock attention output.
            return output_x * self.SR_delta + flux_hidden_states * (1 - self.SR_delta), flux_encoder_hidden_states, SR_context_attn_output_list

        return matsepcalc(hidden_states, contexts_list, image_rotary_emb)

    return forward

def split_dims(x_t, height, width, self=None):
    """Split an attention layer dimension to height + width.
    The original estimate was latent_h = sqrt(hw_ratio*x_t),
    rounding to the nearest value. However, this proved inaccurate.
    The actual operation seems to be as follows:
    - Divide h,w by 8, rounding DOWN.
    - For every new layer (of 4), divide both by 2 and round UP (then back up).
    - Multiply h*w to yield x_t.
    There is no inverse function to this set of operations,
    so instead we mimic them without the multiplication part using the original h+w.
    It's worth noting that no known checkpoints follow a different system of layering,
    but it's theoretically possible. Please report if encountered.
    """
    scale = math.ceil(math.log2(math.sqrt(height * width / x_t)))
    latent_h = repeat_div(height, scale)
    latent_w = repeat_div(width, scale)
    if x_t > latent_h * latent_w and hasattr(self, "nei_multi"):
        latent_h, latent_w = self.nei_multi[1], self.nei_multi[0] 
        while latent_h * latent_w != x_t:
            latent_h, latent_w = latent_h // 2, latent_w // 2

    return latent_h, latent_w
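# Worked example for split_dims (illustrative, assuming Flux's usual 16x reduction
# from pixels to packed latent tokens): a 1024x1024 request yields a 64x64 grid of
# image tokens, i.e. x_t = 4096, and
#   scale    = ceil(log2(sqrt(1024 * 1024 / 4096))) = ceil(log2(16)) = 4
#   latent_h = repeat_div(1024, 4) = 64,  latent_w = repeat_div(1024, 4) = 64
#   latent_h * latent_w == 4096 == x_t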

def repeat_div(x, y):
    """Imitate the repeated dimension halving common in convolution stacks.

    This is a fairly strong assumption about the model, but a model that does
    not behave this way would be easy to spot.
    """
    while y > 0:
        x = math.ceil(x / 2)
        y = y - 1
    return x


def init_forwards(self, root_module: torch.nn.Module):
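    # Install plain pass-through forwards that delegate to the stock processor and
    # ignore the extra regional keyword arguments; presumably used to restore the
    # default attention behaviour before regional hooking.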
    for name, module in root_module.named_modules():
        if "attn" in name and "transformer_blocks" in name  and "single_transformer_blocks" not in name and module.__class__.__name__ == "Attention":
            module.forward = FluxTransformerBlock_init_forward(self, module)           
        elif "attn" in name and "single_transformer_blocks" in name and module.__class__.__name__ == "Attention":
            module.forward = FluxSingleTransformerBlock_init_forward(self, module) 

def FluxSingleTransformerBlock_init_forward(self, module):
    def forward(hidden_states=None, encoder_hidden_states=None, image_rotary_emb=None, RPG_encoder_hidden_states_list=None, RPG_norm_encoder_hidden_states_list=None, RPG_hidden_states_list=None, RPG_norm_hidden_states_list=None):
        return module.processor(module, hidden_states=hidden_states, image_rotary_emb=image_rotary_emb)
    return forward

def FluxTransformerBlock_init_forward(self, module):
    def forward(hidden_states=None, encoder_hidden_states=None, image_rotary_emb=None, RPG_encoder_hidden_states_list=None, RPG_norm_encoder_hidden_states_list=None, RPG_hidden_states_list=None, RPG_norm_hidden_states_list=None):
        return module.processor(module, hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, image_rotary_emb=image_rotary_emb)
    return forward