import math

import torch
import torchvision.transforms.functional as F

TOKENS = 75

def hook_forwards(self, root_module: torch.nn.Module):
    # Replace the forward of every Attention module in the Flux transformer:
    # double-stream blocks ("transformer_blocks") and single-stream blocks
    # ("single_transformer_blocks") receive separate region-aware hooks.
    for name, module in root_module.named_modules():
        if "attn" in name and "transformer_blocks" in name and "single_transformer_blocks" not in name and module.__class__.__name__ == "Attention":
            module.forward = FluxTransformerBlock_hook_forward(self, module)
        elif "attn" in name and "single_transformer_blocks" in name and module.__class__.__name__ == "Attention":
            module.forward = FluxSingleTransformerBlock_hook_forward(self, module)
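
# Hedged usage sketch (not part of the original hooks). Assuming a diffusers FluxPipeline
# whose transformer exposes `transformer_blocks` / `single_transformer_blocks`, and a handler
# object carrying self.h, self.w, self.split_ratio and self.SR_delta, installation would look
# roughly like:
#
#     hook_forwards(handler, pipe.transformer)   # patch every hooked Attention module
#     ... run the regional denoising loop ...
#     init_forwards(handler, pipe.transformer)   # restore plain processor calls afterwards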

def FluxSingleTransformerBlock_hook_forward(self, module):
    def forward(hidden_states=None, encoder_hidden_states=None, image_rotary_emb=None, SR_encoder_hidden_states_list=None, SR_norm_encoder_hidden_states_list=None, SR_hidden_states_list=None, SR_norm_hidden_states_list=None):
        # Plain Flux attention on the full (text + image) token sequence.
        flux_hidden_states = module.processor(module, hidden_states=hidden_states, image_rotary_emb=image_rotary_emb)

        height = self.h
        width = self.w
        x_t = hidden_states.size()[1] - 512  # image tokens only; the first 512 tokens are text
        scale = round(math.sqrt(height * width / x_t))
        latent_h = round(height / scale)
        latent_w = round(width / scale)
        ha, wa = x_t % latent_h, x_t % latent_w
        if ha == 0:
            latent_w = int(x_t / latent_h)
        elif wa == 0:
            latent_h = int(x_t / latent_w)

        contexts_list = SR_norm_hidden_states_list

        def single_matsepcalc(x, contexts_list, image_rotary_emb):
            h_states = []
            x_t = x.size()[1] - 512
            (latent_h, latent_w) = split_dims(x_t, height, width, self)

            latent_out = latent_w
            latent_in = latent_h

            i = 0
            sumout = 0
            SR_all_out_list = []
            for drow in self.split_ratio:
                v_states = []
                sumin = 0
                for dcell in drow.cols:
                    # Each cell consumes 1 + dcell.breaks entries from the regional contexts.
                    context = contexts_list[i]
                    i = i + 1 + dcell.breaks
                    SR_all_out = module.processor(module, hidden_states=context, image_rotary_emb=image_rotary_emb)
                    out = SR_all_out[:, 512:, ...]
                    out = out.reshape(out.size()[0], latent_h, latent_w, out.size()[2])
                    addout = 0
                    addin = 0
                    sumin = sumin + int(latent_in * dcell.end) - int(latent_in * dcell.start)
                    if dcell.end >= 0.999:
                        addin = sumin - latent_in
                        sumout = sumout + int(latent_out * drow.end) - int(latent_out * drow.start)
                        if drow.end >= 0.999:
                            addout = sumout - latent_out
                    # Crop this cell's slice of the latent grid (rows from drow, columns from dcell).
                    out = out[:, int(latent_h * drow.start) + addout:int(latent_h * drow.end),
                              int(latent_w * dcell.start) + addin:int(latent_w * dcell.end), :]
                    v_states.append(out)
                    SR_all_out_list.append(SR_all_out)
                output_x = torch.cat(v_states, dim=2)  # cells of a row, side by side
                h_states.append(output_x)
            output_x = torch.cat(h_states, dim=1)  # rows stacked top to bottom
            output_x = output_x.reshape(x.size()[0], x.size()[1] - 512, x.size()[2])

            new_SR_all_out_list = []
            for SR_all_out in SR_all_out_list:
                SR_all_out[:, 512:, ...] = output_x
                new_SR_all_out_list.append(SR_all_out)

            # Blend the regional composite into the image tokens, weighted by SR_delta.
            x[:, 512:, ...] = output_x * self.SR_delta + x[:, 512:, ...] * (1 - self.SR_delta)
            return x, new_SR_all_out_list

        return single_matsepcalc(flux_hidden_states, contexts_list, image_rotary_emb)

    return forward
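
# Hedged illustration (not in the original file): the nested drow/dcell loops above slice the
# latent grid by fractional start/end ratios and stitch the pieces back together. The
# SimpleNamespace rows/cells below are an assumption about the shape of self.split_ratio, used
# only to show the cropping and concatenation arithmetic without the addin/addout rounding fix.
def _region_slicing_sketch():
    from types import SimpleNamespace

    latent_h, latent_w = 8, 8
    x = torch.arange(latent_h * latent_w).reshape(1, latent_h, latent_w, 1)

    # One full-width top row, and a bottom row split into left/right halves.
    split_ratio = [
        SimpleNamespace(start=0.0, end=0.5, cols=[SimpleNamespace(start=0.0, end=1.0, breaks=0)]),
        SimpleNamespace(start=0.5, end=1.0, cols=[SimpleNamespace(start=0.0, end=0.5, breaks=0),
                                                  SimpleNamespace(start=0.5, end=1.0, breaks=0)]),
    ]

    h_states = []
    for drow in split_ratio:
        v_states = []
        for dcell in drow.cols:
            out = x[:, int(latent_h * drow.start):int(latent_h * drow.end),
                    int(latent_w * dcell.start):int(latent_w * dcell.end), :]
            v_states.append(out)
        h_states.append(torch.cat(v_states, dim=2))  # cells of a row, side by side
    merged = torch.cat(h_states, dim=1)  # rows stacked top to bottom
    assert merged.shape == x.shape  # the regions tile the full latent grid
    return merged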

def FluxTransformerBlock_hook_forward(self, module):
    def forward(hidden_states=None, encoder_hidden_states=None, image_rotary_emb=None, SR_encoder_hidden_states_list=None, SR_norm_encoder_hidden_states_list=None, SR_hidden_states_list=None, SR_norm_hidden_states_list=None):
        # Plain Flux attention; double-stream blocks return image and text streams separately.
        flux_hidden_states, flux_encoder_hidden_states = module.processor(module, hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, image_rotary_emb=image_rotary_emb)

        height = self.h
        width = self.w
        x_t = hidden_states.size()[1]  # here hidden_states holds image tokens only
        scale = round(math.sqrt(height * width / x_t))
        latent_h = round(height / scale)
        latent_w = round(width / scale)
        ha, wa = x_t % latent_h, x_t % latent_w
        if ha == 0:
            latent_w = int(x_t / latent_h)
        elif wa == 0:
            latent_h = int(x_t / latent_w)

        contexts_list = SR_norm_encoder_hidden_states_list

        def matsepcalc(x, contexts_list, image_rotary_emb):
            h_states = []
            x_t = x.size()[1]
            (latent_h, latent_w) = split_dims(x_t, height, width, self)

            latent_out = latent_w
            latent_in = latent_h

            i = 0
            sumout = 0
            SR_context_attn_output_list = []
            for drow in self.split_ratio:
                v_states = []
                sumin = 0
                for dcell in drow.cols:
                    # Each cell consumes 1 + dcell.breaks entries from the regional contexts.
                    context = contexts_list[i]
                    i = i + 1 + dcell.breaks
                    out, SR_context_attn_output = module.processor(module, hidden_states=x, encoder_hidden_states=context, image_rotary_emb=image_rotary_emb)
                    out = out.reshape(out.size()[0], latent_h, latent_w, out.size()[2])
                    addout = 0
                    addin = 0
                    sumin = sumin + int(latent_in * dcell.end) - int(latent_in * dcell.start)
                    if dcell.end >= 0.999:
                        addin = sumin - latent_in
                        sumout = sumout + int(latent_out * drow.end) - int(latent_out * drow.start)
                        if drow.end >= 0.999:
                            addout = sumout - latent_out
                    # Crop this cell's slice of the latent grid (rows from drow, columns from dcell).
                    out = out[:, int(latent_h * drow.start) + addout:int(latent_h * drow.end),
                              int(latent_w * dcell.start) + addin:int(latent_w * dcell.end), :]
                    v_states.append(out)
                    SR_context_attn_output_list.append(SR_context_attn_output)
                output_x = torch.cat(v_states, dim=2)  # cells of a row, side by side
                h_states.append(output_x)
            output_x = torch.cat(h_states, dim=1)  # rows stacked top to bottom
            output_x = output_x.reshape(x.size()[0], x.size()[1], x.size()[2])

            # Blend the regional composite with the unmodified Flux attention output.
            return output_x * self.SR_delta + flux_hidden_states * (1 - self.SR_delta), flux_encoder_hidden_states, SR_context_attn_output_list

        return matsepcalc(hidden_states, contexts_list, image_rotary_emb)

    return forward
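
# Hedged note (assumption, not in the original file): the return value above interpolates between
# the region-wise composite and the unmodified Flux attention output, weighted by self.SR_delta.
def _delta_blend_sketch(regional: torch.Tensor, base: torch.Tensor, delta: float) -> torch.Tensor:
    # delta=1.0 keeps only the regional attention, delta=0.0 keeps the plain Flux output.
    return regional * delta + base * (1 - delta)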

def split_dims(x_t, height, width, self=None):
    """Split an attention layer dimension into height and width.

    The original estimate was latent_h = sqrt(hw_ratio * x_t), rounded to the
    nearest value, but this proved inaccurate. The actual operation seems to be:
    - Divide h, w by 8, rounding DOWN.
    - For every new layer (of 4), divide both by 2 and round UP (then back up).
    - Multiply h * w to yield x_t.
    There is no inverse to this set of operations, so instead we mimic them,
    without the multiplication step, starting from the original h and w.
    No known checkpoints follow a different system of layering, but it is
    theoretically possible; please report it if encountered.
    """
    scale = math.ceil(math.log2(math.sqrt(height * width / x_t)))
    latent_h = repeat_div(height, scale)
    latent_w = repeat_div(width, scale)
    if x_t > latent_h * latent_w and hasattr(self, "nei_multi"):
        latent_h, latent_w = self.nei_multi[1], self.nei_multi[0]
        while latent_h * latent_w != x_t:
            latent_h, latent_w = latent_h // 2, latent_w // 2
    return latent_h, latent_w

def repeat_div(x, y):
    """Imitate the dimension halving common in convolution operations.

    This is a fairly big assumption about the model, but if some model
    does not work like that, it will be easy to spot.
    """
    while y > 0:
        x = math.ceil(x / 2)
        y = y - 1
    return x
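
# Hedged worked example (not in the original file), assuming Flux packs latents so that a
# 1024x1024 image yields 64*64 = 4096 image tokens (VAE /8, then 2x2 patch packing):
#
#     split_dims(4096, 1024, 1024)  ->  (64, 64)   # scale = ceil(log2(sqrt(1024*1024/4096))) = 4
#     split_dims(4032, 1152, 896)   ->  (72, 56)   # non-square: repeat_div(1152, 4), repeat_div(896, 4)
#     repeat_div(1152, 4)           ->  72         # 1152 -> 576 -> 288 -> 144 -> 72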

def init_forwards(self, root_module: torch.nn.Module):
    # Restore plain (non-regional) attention forwards on the same modules hooked above.
    for name, module in root_module.named_modules():
        if "attn" in name and "transformer_blocks" in name and "single_transformer_blocks" not in name and module.__class__.__name__ == "Attention":
            module.forward = FluxTransformerBlock_init_forward(self, module)
        elif "attn" in name and "single_transformer_blocks" in name and module.__class__.__name__ == "Attention":
            module.forward = FluxSingleTransformerBlock_init_forward(self, module)


def FluxSingleTransformerBlock_init_forward(self, module):
    def forward(hidden_states=None, encoder_hidden_states=None, image_rotary_emb=None, RPG_encoder_hidden_states_list=None, RPG_norm_encoder_hidden_states_list=None, RPG_hidden_states_list=None, RPG_norm_hidden_states_list=None):
        return module.processor(module, hidden_states=hidden_states, image_rotary_emb=image_rotary_emb)
    return forward


def FluxTransformerBlock_init_forward(self, module):
    def forward(hidden_states=None, encoder_hidden_states=None, image_rotary_emb=None, RPG_encoder_hidden_states_list=None, RPG_norm_encoder_hidden_states_list=None, RPG_hidden_states_list=None, RPG_norm_hidden_states_list=None):
        return module.processor(module, hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, image_rotary_emb=image_rotary_emb)
    return forward