Spaces:
Build error
Build error
| # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py | |
| from typing import Any, Dict, Optional | |
| import torch | |
| from diffusers.models.attention import AdaLayerNorm, Attention, FeedForward | |
| from diffusers.models.embeddings import SinusoidalPositionalEmbedding | |
| from einops import rearrange | |
| from torch import nn | |
| from diffusers.models.attention import * | |
| from diffusers.models.attention_processor import * | |
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block: self-attention, optional second attention, then feed-forward.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (`int`, *optional*):
            The number of diffusion steps used during training. Required when `norm_type` is `"ada_norm"` or
            `"ada_norm_zero"`.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"`, `"ada_norm_zero"` or
            `"ada_norm_single"`.
        norm_eps (`float`, *optional*, defaults to 1e-5):
            Epsilon passed to the `nn.LayerNorm` layers created by this block.
        final_dropout (`bool` *optional*, defaults to False):
            Whether to apply a final dropout after the last feed-forward layer.
        attention_type (`str`, *optional*, defaults to `"default"`):
            The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
        positional_embeddings (`str`, *optional*, defaults to `None`):
            The type of positional embeddings to apply to. Only `"sinusoidal"` creates an embedding layer.
        num_positional_embeddings (`int`, *optional*, defaults to `None`):
            The maximum number of positional embeddings to apply.

    Raises:
        ValueError: if an adaptive `norm_type` is requested without `num_embeds_ada_norm`, or if
            `positional_embeddings` is set without `num_positional_embeddings`.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",  # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
        norm_eps: float = 1e-5,
        final_dropout: bool = False,
        attention_type: str = "default",
        positional_embeddings: Optional[str] = None,
        num_positional_embeddings: Optional[int] = None,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        # Exactly one of these flags drives the normalization path taken in `forward`.
        self.use_ada_layer_norm_zero = (
            num_embeds_ada_norm is not None
        ) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (
            num_embeds_ada_norm is not None
        ) and norm_type == "ada_norm"
        self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
        self.use_layer_norm = norm_type == "layer_norm"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        if positional_embeddings and (num_positional_embeddings is None):
            raise ValueError(
                "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
            )

        if positional_embeddings == "sinusoidal":
            self.pos_embed = SinusoidalPositionalEmbedding(
                dim, max_seq_length=num_positional_embeddings
            )
        else:
            self.pos_embed = None

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(
                dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
            )
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(
                    dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
                )
            )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim
                if not double_self_attention
                else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        # In the "ada_norm_single" mode `forward` normalizes with `norm2` instead, so no `norm3` is created.
        if not self.use_ada_layer_norm_single:
            self.norm3 = nn.LayerNorm(
                dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
            )
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
        )

        # 4. Fuser (only exists for gated attention; `forward` uses it when "gligen" kwargs are passed)
        if attention_type == "gated" or attention_type == "gated-text-image":
            self.fuser = GatedSelfAttentionDense(
                dim, cross_attention_dim, num_attention_heads, attention_head_dim
            )

        # 5. Scale-shift for PixArt-Alpha.
        if self.use_ada_layer_norm_single:
            self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        """Enable (or disable with `None`) chunked evaluation of the feed-forward layer.

        Args:
            chunk_size: size of each slice taken along `dim`; `None` runs the feed-forward in one pass.
            dim: dimension of the feed-forward input along which it is sliced.
        """
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        class_labels: Optional[torch.LongTensor] = None,
    ) -> torch.FloatTensor:
        """Run self-attention, the optional second attention and the feed-forward, each with a residual.

        Args:
            hidden_states: input tensor; its first dimension is used as the batch size.
            attention_mask: mask passed to the first attention layer.
            encoder_hidden_states: conditioning passed to the second attention layer (and to the first one
                when `only_cross_attention` is set).
            encoder_attention_mask: mask passed to the second attention layer.
            timestep: consumed by the adaptive norm layers; in "ada_norm_single" mode it is reshaped to
                `(batch, 6, -1)` and added to `scale_shift_table`.
            cross_attention_kwargs: extra keyword arguments for both attention layers. `"scale"` is also
                read as the feed-forward scale, and `"gligen"` is removed and routed to the fuser.
            class_labels: passed to `norm1` in "ada_norm_zero" mode.

        Returns:
            The transformed hidden states.

        Raises:
            ValueError: if no normalization mode is active, or if chunking is enabled and the chunked
                dimension is not divisible by the chunk size.
        """
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Self-Attention
        batch_size = hidden_states.shape[0]

        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        elif self.use_layer_norm:
            norm_hidden_states = self.norm1(hidden_states)
        elif self.use_ada_layer_norm_single:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
            ).chunk(6, dim=1)
            norm_hidden_states = self.norm1(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
            norm_hidden_states = norm_hidden_states.squeeze(1)
        else:
            raise ValueError("Incorrect norm used")

        if self.pos_embed is not None:
            norm_hidden_states = self.pos_embed(norm_hidden_states)

        # 1. Retrieve lora scale.
        lora_scale = (
            cross_attention_kwargs.get("scale", 1.0)
            if cross_attention_kwargs is not None
            else 1.0
        )

        # 2. Prepare GLIGEN inputs
        # Copy first so that popping "gligen" does not mutate the caller's dict.
        cross_attention_kwargs = (
            cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
        )
        gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states
            if self.only_cross_attention
            else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        elif self.use_ada_layer_norm_single:
            attn_output = gate_msa * attn_output

        hidden_states = attn_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # 2.5 GLIGEN Control
        if gligen_kwargs is not None:
            hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])

        # 3. Cross-Attention
        if self.attn2 is not None:
            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm2(hidden_states, timestep)
            elif self.use_ada_layer_norm_zero or self.use_layer_norm:
                norm_hidden_states = self.norm2(hidden_states)
            elif self.use_ada_layer_norm_single:
                # For PixArt norm2 isn't applied here:
                # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
                norm_hidden_states = hidden_states
            else:
                raise ValueError("Incorrect norm")

            if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
                norm_hidden_states = self.pos_embed(norm_hidden_states)

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 4. Feed-forward
        if not self.use_ada_layer_norm_single:
            norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = (
                norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
            )

        if self.use_ada_layer_norm_single:
            norm_hidden_states = self.norm2(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

        if self._chunk_size is not None:
            # Honour `set_chunk_feed_forward`: evaluate the feed-forward slice by slice to lower peak memory.
            chunk_extent = norm_hidden_states.shape[self._chunk_dim]
            if chunk_extent % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {chunk_extent} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )
            num_chunks = chunk_extent // self._chunk_size
            ff_output = torch.cat(
                [
                    self.ff(hid_slice, scale=lora_scale)
                    for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
                ],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states, scale=lora_scale)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output
        elif self.use_ada_layer_norm_single:
            ff_output = gate_mlp * ff_output

        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        return hidden_states
class TemporalBasicTransformerBlock(nn.Module):
    """
    Transformer block with self-attention, optional cross-attention, a feed-forward
    layer and an optional attention pass across the frame axis.

    Parameters:
        dim (`int`): channel count of the input and output.
        num_attention_heads (`int`): number of attention heads.
        attention_head_dim (`int`): channels per head.
        dropout (`float`): dropout probability for the attention and feed-forward layers.
        cross_attention_dim (`int`, *optional*): when given, a cross-attention layer and its norm are created.
        activation_fn (`str`): activation used by the feed-forward layer.
        num_embeds_ada_norm (`int`, *optional*): when given, `AdaLayerNorm` replaces `nn.LayerNorm`
            for the attention norms and `timestep` is passed to them in `forward`.
        attention_bias (`bool`): whether the attention projections use a bias.
        only_cross_attention (`bool`): stored on the instance; not otherwise used by this class.
        upcast_attention (`bool`): forwarded to every `Attention` layer.
        unet_use_cross_frame_attention: when truthy, `video_length` is passed to the first attention layer.
        unet_use_temporal_attention: must not be `None`; when truthy the temporal attention is built and applied.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
    ):
        super().__init__()
        adaptive_norm = num_embeds_ada_norm is not None

        self.only_cross_attention = only_cross_attention
        self.use_ada_layer_norm = adaptive_norm
        self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
        self.unet_use_temporal_attention = unet_use_temporal_attention

        def new_attention_norm():
            # Attention norms are timestep-conditioned only when `num_embeds_ada_norm` was supplied.
            if adaptive_norm:
                return AdaLayerNorm(dim, num_embeds_ada_norm)
            return nn.LayerNorm(dim)

        # Arguments shared by every attention layer of this block.
        attention_args = dict(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            upcast_attention=upcast_attention,
        )

        # Layers are created in a fixed order: attn1, norm1, attn2, norm2, ff, norm3, attn_temp, norm_temp.
        # SC-Attn
        self.attn1 = Attention(**attention_args)
        self.norm1 = new_attention_norm()

        # Cross-Attn
        with_cross_attention = cross_attention_dim is not None
        self.attn2 = (
            Attention(cross_attention_dim=cross_attention_dim, **attention_args)
            if with_cross_attention
            else None
        )
        self.norm2 = new_attention_norm() if with_cross_attention else None

        # Feed-forward
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
        self.norm3 = nn.LayerNorm(dim)
        self.use_ada_layer_norm_zero = False

        # Temp-Attn
        assert unet_use_temporal_attention is not None
        if unet_use_temporal_attention:
            self.attn_temp = Attention(**attention_args)
            # Zero output projection weight for the temporal attention.
            nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
            self.norm_temp = new_attention_norm()

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        timestep=None,
        attention_mask=None,
        video_length=None,
    ):
        """Apply the block and return the updated hidden states.

        `hidden_states` is indexed as `(batch * frames, tokens, channels)` by the
        temporal rearrange; `video_length` is the number of frames used there.
        `attention_mask` is given to both the self- and the cross-attention layer.
        """

        def normalize(layer, tensor):
            # Adaptive norms additionally take the timestep.
            if self.use_ada_layer_norm:
                return layer(tensor, timestep)
            return layer(tensor)

        # Self-attention (optionally told the clip length for cross-frame attention).
        self_attn_extras = {}
        if self.unet_use_cross_frame_attention:
            self_attn_extras["video_length"] = video_length
        attended = self.attn1(
            normalize(self.norm1, hidden_states),
            attention_mask=attention_mask,
            **self_attn_extras,
        )
        hidden_states = attended + hidden_states

        # Cross-attention
        if self.attn2 is not None:
            attended = self.attn2(
                normalize(self.norm2, hidden_states),
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
            )
            hidden_states = attended + hidden_states

        # Feed-forward
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

        if not self.unet_use_temporal_attention:
            return hidden_states

        # Temporal attention: fold tokens into the batch so attention runs along the frame axis.
        token_count = hidden_states.shape[1]
        per_token = rearrange(
            hidden_states, "(b f) d c -> (b d) f c", f=video_length
        )
        per_token = self.attn_temp(normalize(self.norm_temp, per_token)) + per_token
        return rearrange(per_token, "(b d) f c -> (b f) d c", d=token_count)
class ResidualTemporalBasicTransformerBlock(TemporalBasicTransformerBlock):
    """
    Variant of `TemporalBasicTransformerBlock` whose self- and cross-attention
    layers are `ResidualAttention`, so per-block residual tensors can be passed
    down to the attention processor.

    The constructor takes the same arguments as the parent class. It deliberately
    skips the parent's `__init__` (calling `nn.Module.__init__` directly) and
    rebuilds every layer itself; the temporal attention stays a plain `Attention`.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
    ):
        # Skip TemporalBasicTransformerBlock.__init__ so its plain Attention layers are never built.
        super(TemporalBasicTransformerBlock, self).__init__()
        self.only_cross_attention = only_cross_attention
        self.use_ada_layer_norm = num_embeds_ada_norm is not None
        self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
        self.unet_use_temporal_attention = unet_use_temporal_attention

        # SC-Attn
        self.attn1 = ResidualAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            upcast_attention=upcast_attention,
        )
        self.norm1 = (
            AdaLayerNorm(dim, num_embeds_ada_norm)
            if self.use_ada_layer_norm
            else nn.LayerNorm(dim)
        )

        # Cross-Attn
        if cross_attention_dim is not None:
            self.attn2 = ResidualAttention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
        else:
            self.attn2 = None

        if cross_attention_dim is not None:
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim)
            )
        else:
            self.norm2 = None

        # Feed-forward
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
        self.norm3 = nn.LayerNorm(dim)
        self.use_ada_layer_norm_zero = False

        # Temp-Attn
        assert unet_use_temporal_attention is not None
        if unet_use_temporal_attention:
            self.attn_temp = Attention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
            nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
            self.norm_temp = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim)
            )

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        timestep=None,
        attention_mask=None,
        video_length=None,
        block_idx: Optional[int] = None,
        additional_residuals: Optional[Dict[str, torch.FloatTensor]] = None
    ):
        """Apply the block, forwarding residual information to both attention layers.

        Args:
            hidden_states: input tensor, indexed as `(batch * frames, tokens, channels)` by the temporal rearrange.
            encoder_hidden_states: conditioning for the cross-attention layer.
            timestep: passed to the norm layers when adaptive layer norm is in use.
            attention_mask: given to both the self- and the cross-attention layer.
            video_length: number of frames; used by the temporal rearrange and, for
                cross-frame attention, passed to the first attention layer.
            block_idx: index forwarded to the attention layers together with `additional_residuals`.
            additional_residuals: optional mapping of residual tensors forwarded to the attention layers.

        Returns:
            The updated hidden states.
        """
        norm_hidden_states = (
            self.norm1(hidden_states, timestep)
            if self.use_ada_layer_norm
            else self.norm1(hidden_states)
        )

        # Tell the attention layer explicitly which kind of attention it is performing. Previously the
        # flag was never passed (it stayed `None`), so the self-attention layer was treated like
        # cross-attention when residuals were looked up.
        # NOTE(review): this changes which residual keys attn1 consumes — confirm against the
        # checkpoint / residual producer this block is used with.
        self_attn_kwargs = dict(
            attention_mask=attention_mask,
            block_idx=block_idx,
            additional_residuals=additional_residuals,
            is_self_attn=True,
        )
        if self.unet_use_cross_frame_attention:
            self_attn_kwargs["video_length"] = video_length
        hidden_states = self.attn1(norm_hidden_states, **self_attn_kwargs) + hidden_states

        if self.attn2 is not None:
            # Cross-Attention
            norm_hidden_states = (
                self.norm2(hidden_states, timestep)
                if self.use_ada_layer_norm
                else self.norm2(hidden_states)
            )
            hidden_states = (
                self.attn2(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    block_idx=block_idx,
                    additional_residuals=additional_residuals,
                    is_self_attn=False,
                )
                + hidden_states
            )

        # Feed-forward
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

        # Temporal-Attention
        if self.unet_use_temporal_attention:
            # Fold tokens into the batch so attention runs along the frame axis, then undo the fold.
            d = hidden_states.shape[1]
            hidden_states = rearrange(
                hidden_states, "(b f) d c -> (b d) f c", f=video_length
            )
            norm_hidden_states = (
                self.norm_temp(hidden_states, timestep)
                if self.use_ada_layer_norm
                else self.norm_temp(hidden_states)
            )
            hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
            hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)

        return hidden_states
class ResidualAttention(Attention):
    """
    `Attention` subclass that can hand per-block residual tensors to its processor.

    Differences from the base class visible here:
      * when xformers is enabled and no LoRA / custom-diffusion / added-KV processor
        is installed, `ResidualXFormersAttnProcessor` is used;
      * `forward` accepts `block_idx`, `additional_residuals` and `is_self_attn`.

    NOTE(review): `LORA_ATTENTION_PROCESSORS`, `is_xformers_available`, `xformers`, `F`,
    `logger` and `Callable` are not imported by name in this file — presumably they come
    from the wildcard imports at the top; verify against the installed diffusers version.
    """

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
    ):
        """Switch the attention processor between its xformers and non-xformers variant.

        Args:
            use_memory_efficient_attention_xformers: `True` installs an xformers-based processor,
                `False` installs the regular counterpart of the current processor.
            attention_op: optional operator forwarded to the xformers processors.

        Raises:
            NotImplementedError: xformers requested for an added-KV processor combined with LoRA
                or custom diffusion.
            ModuleNotFoundError: xformers requested but not installed.
            ValueError: xformers requested but CUDA is unavailable.
        """
        # `processor` may not exist yet; `None` simply matches none of the processor types below.
        current = getattr(self, "processor", None)
        is_lora = isinstance(current, LORA_ATTENTION_PROCESSORS)
        is_custom_diffusion = isinstance(
            current, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
        )
        is_added_kv_processor = isinstance(
            current,
            (
                AttnAddedKVProcessor,
                AttnAddedKVProcessor2_0,
                SlicedAttnAddedKVProcessor,
                XFormersAttnAddedKVProcessor,
                LoRAAttnAddedKVProcessor,
            ),
        )

        if use_memory_efficient_attention_xformers:
            if is_added_kv_processor and (is_lora or is_custom_diffusion):
                raise NotImplementedError(
                    f"Memory efficient attention is currently not supported for LoRA or custom diffuson for attention processor type {self.processor}"
                )
            if not is_xformers_available():
                raise ModuleNotFoundError(
                    (
                        "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                        " xformers"
                    ),
                    name="xformers",
                )
            if not torch.cuda.is_available():
                raise ValueError(
                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                    " only available for GPU "
                )

            # Make sure we can run the memory efficient attention; any failure propagates unchanged.
            _ = xformers.ops.memory_efficient_attention(
                torch.randn((1, 2, 40), device="cuda"),
                torch.randn((1, 2, 40), device="cuda"),
                torch.randn((1, 2, 40), device="cuda"),
            )

            if is_lora:
                # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
                # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
                processor = LoRAXFormersAttnProcessor(
                    hidden_size=current.hidden_size,
                    cross_attention_dim=current.cross_attention_dim,
                    rank=current.rank,
                    attention_op=attention_op,
                )
                processor.load_state_dict(current.state_dict())
                processor.to(current.to_q_lora.up.weight.device)
            elif is_custom_diffusion:
                processor = CustomDiffusionXFormersAttnProcessor(
                    train_kv=current.train_kv,
                    train_q_out=current.train_q_out,
                    hidden_size=current.hidden_size,
                    cross_attention_dim=current.cross_attention_dim,
                    attention_op=attention_op,
                )
                processor.load_state_dict(current.state_dict())
                if hasattr(current, "to_k_custom_diffusion"):
                    processor.to(current.to_k_custom_diffusion.weight.device)
            elif is_added_kv_processor:
                # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
                # which uses this type of cross attention ONLY because the attention mask of format
                # [0, ..., -10.000, ..., 0, ...,] is not supported
                # throw warning
                logger.info(
                    "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
                )
                processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
            else:
                processor = ResidualXFormersAttnProcessor(attention_op=attention_op)
        else:
            if is_lora:
                attn_processor_class = (
                    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
                )
                processor = attn_processor_class(
                    hidden_size=current.hidden_size,
                    cross_attention_dim=current.cross_attention_dim,
                    rank=current.rank,
                )
                processor.load_state_dict(current.state_dict())
                processor.to(current.to_q_lora.up.weight.device)
            elif is_custom_diffusion:
                processor = CustomDiffusionAttnProcessor(
                    train_kv=current.train_kv,
                    train_q_out=current.train_q_out,
                    hidden_size=current.hidden_size,
                    cross_attention_dim=current.cross_attention_dim,
                )
                processor.load_state_dict(current.state_dict())
                if hasattr(current, "to_k_custom_diffusion"):
                    processor.to(current.to_k_custom_diffusion.weight.device)
            else:
                # set attention processor
                # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
                # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
                # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
                processor = (
                    AttnProcessor2_0()
                    if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
                    else AttnProcessor()
                )

        self.set_processor(processor)

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None,
                block_idx: Optional[int] = None, additional_residuals: Optional[Dict[str, torch.FloatTensor]] = None,
                is_self_attn: Optional[bool] = None, **cross_attention_kwargs):
        """Delegate to the installed attention processor.

        Args:
            hidden_states: input tensor for the attention layer.
            encoder_hidden_states: optional conditioning tensor.
            attention_mask: optional attention mask.
            block_idx: index used by `ResidualXFormersAttnProcessor` to build residual key names.
            additional_residuals: optional mapping of residual tensors for the processor.
            is_self_attn: tells `ResidualXFormersAttnProcessor` whether to look up the
                `self_attn` or the `cross_attn` residual keys.
            **cross_attention_kwargs: passed through to the processor unchanged.

        Returns:
            Whatever the processor returns.
        """
        # The `Attention` class can call different attention processors / attention functions
        # here we simply pass along all tensors to the selected processor class
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
        if additional_residuals:
            # Only forward the residual-specific arguments when there is something to apply:
            # ResidualXFormersAttnProcessor ignores `block_idx` / `is_self_attn` without residuals,
            # and the other processors this class can install presumably do not accept these
            # keywords, so sending them unconditionally would make every call fail there.
            cross_attention_kwargs.update(
                block_idx=block_idx,
                additional_residuals=additional_residuals,
                is_self_attn=is_self_attn,
            )

        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
class ResidualXFormersAttnProcessor(XFormersAttnProcessor):
    """
    xformers attention processor that can add externally supplied residual tensors
    to the query, key, value and (for cross-attention) the encoder hidden states.

    Residuals are looked up in `additional_residuals` under the keys
    `block_{block_idx}_{self|cross}_attn_{q|k|v}` and `block_{block_idx}_cross_attn_c`.
    Missing keys, or an empty/`None` mapping, leave the corresponding tensor unchanged.
    """

    @staticmethod
    def _add_residual(tensor, additional_residuals, block_idx, is_self_attn, suffix):
        """Return `tensor` plus the residual stored for (`block_idx`, attention kind, `suffix`), if any."""
        if not additional_residuals:
            return tensor
        kind = "self" if is_self_attn else "cross"
        key = f"block_{block_idx}_{kind}_attn_{suffix}"
        if key in additional_residuals:
            return tensor + additional_residuals[key]
        return tensor

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        block_idx: Optional[int] = None,
        additional_residuals: Optional[Dict[str, torch.FloatTensor]] = None,
        is_self_attn: Optional[bool] = None
    ):
        """Compute attention with xformers, adding residuals where provided.

        Args:
            attn: the attention module providing projections, norms and configuration.
            hidden_states: query-side input; a 4-D input is flattened to a token sequence and
                restored to its original shape at the end.
            encoder_hidden_states: key/value-side input; `hidden_states` is used when `None`.
            attention_mask: optional mask, prepared via `attn.prepare_attention_mask`.
            temb: passed to `attn.spatial_norm` when that norm exists.
            block_idx: index used to build the residual key names.
            additional_residuals: optional mapping of residual tensors.
            is_self_attn: truthy selects the `self_attn` residual keys; falsy (including the
                default `None`) selects the `cross_attn` keys.

        Returns:
            The attention output, with the input added back when `attn.residual_connection`
            is set, divided by `attn.rescale_output_factor`.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, key_tokens, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
        if attention_mask is not None:
            # expand our mask's singleton query_tokens dimension:
            #   [batch*heads,            1, key_tokens] ->
            #   [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes:
            #   [batch*heads, query_tokens, key_tokens]
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        query = self._add_residual(query, additional_residuals, block_idx, is_self_attn, "q")

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        context_key = f"block_{block_idx}_cross_attn_c"
        if not is_self_attn and additional_residuals and context_key in additional_residuals:
            # Per-sample gate: True where the mean absolute value of the conditioning is below 1e-4,
            # so the context residual is added only to those samples.
            # NOTE(review): the name `not_uc` suggests "not unconditional", yet the mask selects
            # samples whose conditioning is (near) zero — confirm the intended polarity.
            not_uc = encoder_hidden_states.abs().mean(dim=[1, 2], keepdim=True) < 1e-4
            encoder_hidden_states = encoder_hidden_states + additional_residuals[context_key] * not_uc

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = self._add_residual(key, additional_residuals, block_idx, is_self_attn, "k")
        value = self._add_residual(value, additional_residuals, block_idx, is_self_attn, "v")

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states