import torch import torch.nn as nn import torch import torch.nn as nn import torch.nn.functional as F from flash_attn import flash_attn_varlen_func, flash_attn_varlen_qkvpacked_func def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): # x shape: bsz, seqlen, self.n_local_heads, self.head_hidden_dim / 2 # the last shape is "self.hidden_dim / 2" because we convert to complex assert x.ndim == 4 assert freqs_cis.shape == (x.shape[0], x.shape[1], x.shape[-1]), \ f'x shape: {x.shape}, freqs_cis shape: {freqs_cis.shape}' # reshape freq cis to match and apply pointwise multiply # new shape: bsz, seq_len, 1, self.head_hidden_dim / 2 shape = [x.shape[0], x.shape[1], 1, x.shape[-1]] return freqs_cis.view(*shape) def apply_rotary_emb( xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor, ): xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) freqs_cis = reshape_for_broadcast(freqs_cis, xq_) xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) return xq_out.type_as(xq), xk_out.type_as(xk) class Attention(nn.Module): def __init__( self, dim: int, num_heads: int = 8, qkv_bias: bool = False, qk_norm: bool = False, attn_drop: float = 0., proj_drop: float = 0., norm_layer: nn.Module = nn.LayerNorm, flash_attention: bool = True ) -> None: super().__init__() assert dim % num_heads == 0, 'dim should be divisible by num_heads' self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5 self.fused_attn = flash_attention self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.qk_norm = qk_norm self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x: torch.Tensor, seq_len, cu_seqlens, max_seqlen, cu_seqlens_k, max_seqlen_k, rotary_pos_emb=None, incremental_state=None, nopadding=True) -> torch.Tensor: B, N, C = x.shape if self.fused_attn: if nopadding: qkv = self.qkv(x) qkv = qkv.view(B * N, self.num_heads * 3, self.head_dim) q, k, v = qkv.split([self.num_heads] * 3, dim=1) q, k = self.q_norm(q), self.k_norm(k) q = q.view(B, N, self.num_heads, self.head_dim) k = k.view(B, N, self.num_heads, self.head_dim) v = v.view(B, N, self.num_heads, self.head_dim) if rotary_pos_emb is not None: q, k = apply_rotary_emb(q, k, rotary_pos_emb) if incremental_state is not None: if "prev_k" in incremental_state: prev_k = incremental_state["prev_k"] k = torch.cat([prev_k, k], dim=1) if "cur_k" not in incremental_state: incremental_state["cur_k"] = {} incremental_state["cur_k"] = k if "prev_v" in incremental_state: prev_v = incremental_state["prev_v"] v = torch.cat([prev_v, v], dim=1) if "cur_v" not in incremental_state: incremental_state["cur_v"] = {} incremental_state["cur_v"] = v q = q.view(B * N, self.num_heads, self.head_dim) k = k.view(-1, self.num_heads, self.head_dim) v = v.view(-1, self.num_heads, self.head_dim) x = flash_attn_varlen_func( q=q, k=k, v=v, cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen_k, dropout_p=self.attn_drop.p if self.training else 0., ) else: if incremental_state is not None: raise NotImplementedError("It is designed for batching inference. AR-chunk is not supported currently.") qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim) if self.qk_norm: q, k, v = qkv.unbind(2) q, k = self.q_norm(q), self.k_norm(k) # re-bind qkv = torch.stack((q, k, v), dim=2) # pack qkv with seq_len qkv_collect = [] for i in range(qkv.shape[0]): qkv_collect.append( qkv[i, :seq_len[i], :, :, :] ) qkv = torch.cat(qkv_collect, dim=0) x = flash_attn_varlen_qkvpacked_func(qkv=qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, dropout_p=self.attn_drop.p if self.training else 0.) # unpack and pad 0 x_collect = [] for i in range(B): x_collect.append( x[cu_seqlens[i]:cu_seqlens[i+1], :, :] ) x = torch.nn.utils.rnn.pad_sequence(x_collect, batch_first=True, padding_value=0) else: q = q * self.scale attn = q @ k.transpose(-2, -1) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = attn @ v x = x.transpose(1, 2) x = x.reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x def modulate(x, shift, scale): return x * (1 + scale) + shift class FinalLayer(nn.Module): """ The final layer of DiT. """ def __init__(self, hidden_size, out_channels): super().__init__() self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.linear = nn.Linear(hidden_size, out_channels, bias=True) self.adaLN_modulation = nn.Sequential( nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True) ) def forward(self, x, c): shift, scale = self.adaLN_modulation(c).chunk(2, dim=2) x = modulate(self.norm_final(x), shift, scale) x = self.linear(x) return x class DiTBlock(nn.Module): """ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. """ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, ffn_type="conv1d_conv1d", ffn_gated_glu=True, ffn_act_layer="gelu", ffn_conv_kernel_size=5, **block_kwargs): super().__init__() self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) if ffn_type == "vanilla_mlp": from timm.models.vision_transformer import Mlp mlp_hidden_dim = int(hidden_size * mlp_ratio) approx_gelu = lambda: nn.GELU(approximate="tanh") self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) else: raise NotImplementedError(f"FFN type {ffn_type} is not implemented") self.adaLN_modulation = nn.Sequential( nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True) ) def forward(self, x, c, seq_len, cu_seqlens, cu_maxlen, cu_seqlens_k, cu_maxlen_k, mask, rotary_pos_emb=None, incremental_state=None, nopadding=True): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=2) x_ = modulate(self.norm1(x), shift_msa, scale_msa) if incremental_state is not None: if "attn_kvcache" not in incremental_state: incremental_state["attn_kvcache"] = {} inc_attn = incremental_state["attn_kvcache"] else: inc_attn = None x_ = self.attn(x_, seq_len=seq_len, cu_seqlens=cu_seqlens, max_seqlen=cu_maxlen, cu_seqlens_k=cu_seqlens_k, max_seqlen_k=cu_maxlen_k, rotary_pos_emb=rotary_pos_emb, incremental_state=inc_attn, nopadding=nopadding) if not nopadding: x_ = x_ * mask[:, :, None] x = x + gate_msa * x_ x_ = modulate(self.norm2(x), shift_mlp, scale_mlp) x_ = self.mlp(x_) if not nopadding: x_ = x_ * mask[:, :, None] x = x + gate_mlp * x_ return x