from typing import Optional

import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention
from diffusers.models.transformers.transformer_ltx import apply_rotary_emb


class NAGLTXVideoAttentionProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector, and
    additionally applies Normalized Attention Guidance (NAG) when negative context embeddings are concatenated after
    the positive ones along the batch dimension of `encoder_hidden_states`.
    """

    def __init__(self, nag_scale=1.0, nag_tau=2.5, nag_alpha=0.5):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "NAGLTXVideoAttentionProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
            )
        self.nag_scale = nag_scale
        self.nag_tau = nag_tau
        self.nag_alpha = nag_alpha

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # NAG is only active for cross-attention calls where extra (negative) context rows
        # have been appended to `encoder_hidden_states` and the guidance scale is above 1.
        apply_guidance = self.nag_scale > 1 and encoder_hidden_states is not None
        if apply_guidance:
            # Number of negative-context rows: encoder batch size minus latent batch size.
            origin_batch_size = len(encoder_hidden_states) - len(hidden_states)
            assert origin_batch_size > 0

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)

        if apply_guidance:
            # Duplicate the last `origin_batch_size` queries so they also attend to the
            # appended negative context, producing a negative attention branch.
            query = torch.cat([query, query[-origin_batch_size:]], dim=0)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        if apply_guidance:
            # Split the attention outputs into the negative branch (last chunk) and the
            # matching positive branch (second-to-last chunk).
            hidden_states_negative = hidden_states[-origin_batch_size:]
            hidden_states_positive = hidden_states[-origin_batch_size * 2 : -origin_batch_size]

            # Extrapolate the positive output away from the negative one by `nag_scale`.
            hidden_states_guidance = (
                hidden_states_positive * self.nag_scale - hidden_states_negative * (self.nag_scale - 1)
            )

            # Clamp how much the guided output's L2 norm may grow relative to the positive
            # output (ratio capped at `nag_tau`), then blend back with the positive branch
            # using `nag_alpha`.
            norm_positive = torch.norm(hidden_states_positive, p=2, dim=-1, keepdim=True).expand(
                *hidden_states_positive.shape
            )
            norm_guidance = torch.norm(hidden_states_guidance, p=2, dim=-1, keepdim=True).expand(
                *hidden_states_positive.shape
            )
            scale = norm_guidance / norm_positive
            hidden_states_guidance = (
                hidden_states_guidance * torch.minimum(scale, scale.new_ones(1) * self.nag_tau) / scale
            )
            hidden_states_guidance = (
                hidden_states_guidance * self.nag_alpha + hidden_states_positive * (1 - self.nag_alpha)
            )

            # Keep any remaining leading outputs untouched and replace the positive/negative
            # pair with the guided result.
            hidden_states = torch.cat([hidden_states[: -origin_batch_size * 2], hidden_states_guidance], dim=0)

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
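

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the processor). It shows one way
# this processor could be attached to an LTX-Video transformer loaded through
# diffusers. The checkpoint id, the availability of `set_attn_processor` on the
# LTX transformer, and the concrete nag_* values are assumptions. Note that NAG
# only activates when the negative prompt embeddings are concatenated after the
# positive ones along the batch dimension of `encoder_hidden_states` before the
# transformer is called; a plain pipeline call does not do this by itself.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from diffusers import LTXPipeline  # assumed to be available in this diffusers version

    pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)

    # Replace every attention processor in the transformer with the NAG variant.
    pipe.transformer.set_attn_processor(
        NAGLTXVideoAttentionProcessor2_0(nag_scale=4.0, nag_tau=2.5, nag_alpha=0.25)
    )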