Spaces:
Runtime error
Runtime error
| import abc | |
| LOW_RESOURCE = False | |
| import torch | |
| import cv2 | |
| import torch | |
| import os | |
| import numpy as np | |
| from collections import defaultdict | |
| from functools import partial | |
| from typing import Any, Dict, Optional | |
def register_attention_control(unet, config=None):
    """Patch transformer blocks and attention layers of ``unet`` in place.

    The module tree below ``unet`` is walked through ``named_children()``:

    * every sub-module whose class is named ``BasicTransformerBlock`` gets its
      ``forward`` replaced by ``BasicTransformerBlock_forward``, which passes
      the normalized hidden states taken *before* ``pos_embed`` to the
      attention layers via the ``origin_hidden_states`` keyword;
    * every sub-module whose class is named ``Attention`` and whose block name
      (module path up to ``.attn``) matches an entry of
      ``config.model.embedding_layers`` gets the ``forward`` built by
      ``temp_attn_forward``, which derives the value projection from
      ``origin_hidden_states``.

    Args:
        unet: Root module whose children are patched.
        config: Optional configuration object. Only ``model.embedding_layers``
            (iterable of layer-name strings) and ``strategy.get(...)`` are
            read. When ``None``, no ``Attention`` layer is patched.

    Returns:
        None. The modules are modified in place.
    """

    def _chunked_feed_forward(ff, hidden_states, chunk_dim, chunk_size, lora_scale=None):
        """Run ``ff`` over ``chunk_size``-wide slices of ``hidden_states``.

        Local helper: the name is not imported anywhere in the visible part of
        this file, so the chunked branch below would otherwise raise
        ``NameError``.

        Raises:
            ValueError: if the size of ``chunk_dim`` is not a multiple of
                ``chunk_size``.
        """
        dim_size = hidden_states.shape[chunk_dim]
        if dim_size % chunk_size != 0:
            raise ValueError(
                f"`hidden_states` dimension to be chunked: {dim_size} has to be "
                f"divisible by chunk size: {chunk_size}."
            )
        pieces = hidden_states.chunk(dim_size // chunk_size, dim=chunk_dim)
        if lora_scale is None:
            outputs = [ff(piece) for piece in pieces]
        else:
            outputs = [ff(piece, scale=lora_scale) for piece in pieces]
        return torch.cat(outputs, dim=chunk_dim)

    def BasicTransformerBlock_forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        class_labels: Optional[torch.LongTensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.FloatTensor:
        """Replacement ``forward`` for a ``BasicTransformerBlock``.

        Runs self-attention, optional GLIGEN fusion, optional cross-attention
        and the feed-forward layer, each preceded by the normalization chosen
        by ``self.norm_type``. Whenever ``self.pos_embed`` is set (or
        ``self.attn1`` has a ``vSpatial`` attribute) a clone of the normalized
        states taken before ``pos_embed`` is forwarded to the attention layers
        as ``origin_hidden_states``.

        Args:
            hidden_states: Input activations; the first axis is used as the
                batch size.
            attention_mask: Mask forwarded to ``self.attn1``.
            encoder_hidden_states: Conditioning forwarded to ``self.attn2``
                (and to ``self.attn1`` when ``self.only_cross_attention``).
            encoder_attention_mask: Mask forwarded to ``self.attn2``.
            timestep: Used by the ``ada_norm*`` normalization variants.
            cross_attention_kwargs: Extra keyword arguments for the attention
                layers. ``"scale"`` is read as the LoRA scale and ``"gligen"``
                is consumed here. The caller's dict is never modified.
            class_labels: Used by the ``ada_norm_zero`` variant.
            added_cond_kwargs: Must provide ``"pooled_text_emb"`` for the
                ``ada_norm_continuous`` variant.

        Returns:
            The updated hidden states.

        Raises:
            ValueError: if ``self.norm_type`` is not one of the handled values.
        """
        # Copy BEFORE writing anything: the same dict may be shared by every
        # block of the network, and a tensor stored in it would otherwise leak
        # into the caller and into unrelated blocks.
        cross_attention_kwargs = (
            cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
        )

        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Self-Attention
        batch_size = hidden_states.shape[0]

        if self.norm_type == "ada_norm":
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.norm_type == "ada_norm_zero":
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
            norm_hidden_states = self.norm1(hidden_states)
        elif self.norm_type == "ada_norm_continuous":
            norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
        elif self.norm_type == "ada_norm_single":
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
            ).chunk(6, dim=1)
            norm_hidden_states = self.norm1(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
            norm_hidden_states = norm_hidden_states.squeeze(1)
        else:
            raise ValueError("Incorrect norm used")

        # save the origin_hidden_states w/o pos_embed, for the use of motion v embedding
        if self.pos_embed is not None or hasattr(self.attn1, 'vSpatial'):
            cross_attention_kwargs["origin_hidden_states"] = norm_hidden_states.clone()

        if self.pos_embed is not None:
            norm_hidden_states = self.pos_embed(norm_hidden_states)

        # 1. Retrieve lora scale.
        lora_scale = cross_attention_kwargs.get("scale", 1.0)

        # 2. Prepare GLIGEN inputs
        gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
        if self.norm_type == "ada_norm_zero":
            attn_output = gate_msa.unsqueeze(1) * attn_output
        elif self.norm_type == "ada_norm_single":
            attn_output = gate_msa * attn_output

        hidden_states = attn_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # 2.5 GLIGEN Control
        if gligen_kwargs is not None:
            hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])

        # 3. Cross-Attention
        if self.attn2 is not None:
            if self.norm_type == "ada_norm":
                norm_hidden_states = self.norm2(hidden_states, timestep)
            elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
                norm_hidden_states = self.norm2(hidden_states)
            elif self.norm_type == "ada_norm_single":
                # For PixArt norm2 isn't applied here:
                # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
                norm_hidden_states = hidden_states
            elif self.norm_type == "ada_norm_continuous":
                norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
            else:
                raise ValueError("Incorrect norm")

            if self.pos_embed is not None and self.norm_type != "ada_norm_single":
                # Refresh origin_hidden_states with the pre-pos_embed states of
                # this stage; otherwise the value saved for attn1 (if any) is
                # passed on unchanged.
                cross_attention_kwargs["origin_hidden_states"] = norm_hidden_states.clone()
                norm_hidden_states = self.pos_embed(norm_hidden_states)

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # cross_attention_kwargs is a private copy that is not used past this
        # point, so origin_hidden_states needs no explicit removal.

        # 4. Feed-forward
        # i2vgen doesn't have this norm
        if self.norm_type == "ada_norm_continuous":
            norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
        elif not self.norm_type == "ada_norm_single":
            norm_hidden_states = self.norm3(hidden_states)

        if self.norm_type == "ada_norm_zero":
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self.norm_type == "ada_norm_single":
            norm_hidden_states = self.norm2(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            ff_output = _chunked_feed_forward(
                self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
            )
        else:
            ff_output = self.ff(norm_hidden_states, scale=lora_scale)

        if self.norm_type == "ada_norm_zero":
            ff_output = gate_mlp.unsqueeze(1) * ff_output
        elif self.norm_type == "ada_norm_single":
            ff_output = gate_mlp * ff_output

        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        return hidden_states

    def temp_attn_forward(self, additional_info=None):
        """Build a replacement ``forward`` for the attention module ``self``.

        Args:
            self: The attention module being patched.
            additional_info: Optional dict; the keys ``'removeMFromV'`` and
                ``'vSpatial_frameSubtraction'`` select how the value
                projection is computed. When ``None`` the value is projected
                from ``encoder_hidden_states``.

        Returns:
            A ``forward(hidden_states, encoder_hidden_states=None,
            attention_mask=None, temb=None, origin_hidden_states=None)``
            closure bound to ``self``.
        """
        # Only the first entry of a ModuleList ``to_out`` is applied below.
        to_out = self.to_out
        if isinstance(to_out, torch.nn.ModuleList):
            to_out = to_out[0]

        def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, origin_hidden_states=None):
            residual = hidden_states

            if self.spatial_norm is not None:
                hidden_states = self.spatial_norm(hidden_states, temb)

            input_ndim = hidden_states.ndim
            if input_ndim == 4:
                # Flatten the two trailing axes into one sequence axis.
                batch_size, channel, height, width = hidden_states.shape
                hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

            batch_size, sequence_length, _ = (
                hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
            )
            attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)

            if self.group_norm is not None:
                hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

            if encoder_hidden_states is None:
                encoder_hidden_states = hidden_states
            elif self.norm_cross:
                encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)

            query = self.to_q(hidden_states)
            key = self.to_k(encoder_hidden_states)

            # strategies to manipulate the motion value embedding
            if additional_info is None:
                value = self.to_v(encoder_hidden_states)
            elif additional_info['removeMFromV']:
                # empirically, in the inference stage of camera motion
                # discarding the motion value embedding improves the text similarity of the generated video
                value = self.to_v(origin_hidden_states)
            elif hasattr(self, 'vSpatial'):
                if additional_info['vSpatial_frameSubtraction']:
                    # during inference, the debiasing operation helps to generate more diverse videos
                    # refer to the 'Figure.3 Right' in the paper for more details
                    value = self.to_v(self.vSpatial.forward_frameSubtraction(origin_hidden_states))
                else:
                    # during training, do not apply debias operation for motion learning
                    value = self.to_v(self.vSpatial(origin_hidden_states))
            else:
                value = self.to_v(origin_hidden_states)

            query = self.head_to_batch_dim(query)
            key = self.head_to_batch_dim(key)
            value = self.head_to_batch_dim(value)

            attention_probs = self.get_attention_scores(query, key, attention_mask)
            hidden_states = torch.bmm(attention_probs, value)
            hidden_states = self.batch_to_head_dim(hidden_states)

            # linear proj
            hidden_states = to_out(hidden_states)

            if input_ndim == 4:
                hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

            if self.residual_connection:
                hidden_states = hidden_states + residual

            hidden_states = hidden_states / self.rescale_output_factor
            return hidden_states

        return forward

    def register_recr(net_, count, name, config=None):
        """Recursively patch ``net_`` and its children.

        Args:
            net_: Module currently visited.
            count: Number of attention layers patched so far.
            name: Dotted path of ``net_`` relative to ``unet``.
            config: Same object as passed to ``register_attention_control``.

        Returns:
            The updated count of patched ``Attention`` layers.
        """
        class_name = net_.__class__.__name__

        if class_name == 'BasicTransformerBlock':
            # No return here: the block's children (its attention layers)
            # still have to be visited below.
            net_.forward = partial(BasicTransformerBlock_forward, net_)

        if class_name == 'Attention':
            block_name = name.split('.attn')[0]
            if config is not None and block_name in {
                layer.split('.attn')[0].split('.pos_embed')[0]
                for layer in config.model.embedding_layers
            }:
                additional_info = {
                    'layer_name': name,
                    'removeMFromV': config.strategy.get('removeMFromV', False),
                    'vSpatial_frameSubtraction': config.strategy.get('vSpatial_frameSubtraction', False),
                }
                net_.forward = temp_attn_forward(net_, additional_info)
                return count + 1
            return count

        if hasattr(net_, 'children'):
            for child_name, child in net_.named_children():
                count = register_recr(child, count, name=f"{name}.{child_name}", config=config)
        return count

    for child_name, child in unet.named_children():
        register_recr(child, 0, name=child_name, config=config)