|
""" |
|
Based on: https://github.com/lucidrains/flamingo-pytorch |
|
""" |
|
|
|
import math |
|
from typing import Optional, Tuple, Union |
|
from .modeling_internlm2 import InternLM2RMSNorm, InternLM2RotaryEmbedding |
|
from .configuration_mixin import MixinConfig |
|
import torch |
|
from einops import rearrange, repeat |
|
from einops_exts import rearrange_many |
|
from torch import einsum, nn |
|
|
|
from transformers.activations import ACT2FN |
|
|
|
from flash_attn.flash_attn_interface import flash_attn_varlen_func |
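# flash-attn is a hard dependency of this module: the packed (varlen) cross-attention path uses its kernel.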
|
|
|
|
|
def rotate_half(x): |
|
"""Rotates half the hidden dims of the input.""" |
|
x1 = x[..., : x.shape[-1] // 2] |
|
x2 = x[..., x.shape[-1] // 2:] |
|
return torch.cat((-x2, x1), dim=-1) |
|
|
|
def apply_rotary_pos_emb_single(q, cos, sin, position_ids, unsqueeze_dim=1): |
|
"""Applies Rotary Position Embedding to the query and key tensors.""" |
|
cos = cos[position_ids].unsqueeze(unsqueeze_dim).float() |
|
sin = sin[position_ids].unsqueeze(unsqueeze_dim).float() |
|
q_dtype = q.dtype |
|
q = q.float() |
|
q_embed = (q * cos) + (rotate_half(q) * sin) |
|
return q_embed.to(dtype=q_dtype) |
|
|
|
class CrossAttention(nn.Module): |
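    """Text-to-media cross-attention with rotary position embeddings applied to both
    the text queries and the media keys."""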
|
def __init__( |
|
self, |
|
config: MixinConfig |
|
): |
|
super().__init__() |
|
dim = config.language_dim |
|
dim_visual = config.vision_dim |
|
dim_head = config.head_dim |
|
heads = config.num_heads |
|
|
|
self.scale = dim_head**-0.5 |
|
self.heads = heads |
|
inner_dim = dim_head * heads |
|
self.head_dim = dim_head |
|
self.max_position_embeddings = 32768 |
|
|
|
        # Text hidden states provide the queries; media (vision) features provide the keys/values.
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
|
|
self._init_rope() |
|
|
|
        # Caches and packing metadata; expected to be set externally by the wrapping model
        # and left as None when unused.
        self.text_position_ids = None
        self.media_attn_mask = None

        # Cumulative sequence lengths for the packed (varlen) flash-attention path.
        self.cu_seqlens_q = None
        self.cu_seqlens_k = None
|
|
|
def _init_rope(self): |
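        # Reuse InternLM2's rotary embedding implementation for the cross-attention position encoding.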
|
self.rotary_emb = InternLM2RotaryEmbedding( |
|
self.head_dim, |
|
max_position_embeddings=self.max_position_embeddings, |
|
base=1000000, |
|
) |
|
return self.rotary_emb |
|
|
|
def forward(self, x, media, use_cached_media=False, media_position_ids=None, text_position_ids=None, text_time=None): |
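        """Cross-attend text hidden states `x` to `media` features.

        RoPE is applied to the text queries (text positions) and to the media keys
        (media positions). When `use_cached_media` is True, cached text positions are
        extended so incremental decoding keeps consistent absolute positions.
        """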
|
h = self.heads |
|
|
|
q = self.to_q(x) |
|
|
|
k, v = self.to_kv(media).chunk(2, dim=-1) |
|
q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h) |
|
|
|
        if use_cached_media and self.text_position_ids is not None:
            # Incremental decoding: only the newest text token is present, so use its cached position.
            text_position_ids = self.text_position_ids[:, -1].unsqueeze(0)
        t_cos, t_sin = self.rotary_emb(v, seq_len=(text_position_ids.max().item() + 1))
        q = apply_rotary_pos_emb_single(q, t_cos, t_sin, text_position_ids)

        # When media is cached, record the next text position so later decoding steps
        # can recover their absolute position above.
        if use_cached_media:
|
if self.text_position_ids is None: |
|
self.text_position_ids = text_position_ids |
|
next_position_ids = torch.tensor([[self.text_position_ids.shape[1]]], device=self.text_position_ids.device, dtype=self.text_position_ids.dtype) |
|
self.text_position_ids = torch.cat((self.text_position_ids, next_position_ids), dim=1) |
|
|
|
        # Apply RoPE to the media keys using their own position ids.
        m_cos, m_sin = self.rotary_emb(v, seq_len=(media_position_ids.max().item() + 1))
        k = apply_rotary_pos_emb_single(k, m_cos, m_sin, media_position_ids)
|
|
|
        if self.cu_seqlens_k is not None and self.cu_seqlens_q is not None:
            # Packed (varlen) path: flash attention expects (total_tokens, heads, head_dim)
            # inputs and int32 cumulative sequence lengths.
            q = q.transpose(1, 2)
            k = k.transpose(1, 2)
            v = v.transpose(1, 2)
            attn_output = self._flash_attention_forward(
                q, k, v, self.cu_seqlens_q.to(torch.int32), self.cu_seqlens_k.to(torch.int32)
            )
            attn_output = attn_output.unsqueeze(0).transpose(1, 2)
        else:
            # Padded path: standard scaled dot-product attention with the externally set media mask.
            attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, self.media_attn_mask)
|
|
|
        if text_time is not None:
            # text_time == 1 marks text tokens with no associated media; zero out their
            # cross-attention contribution.
            text_without_media_mask = text_time == 1
|
text_without_media_mask = rearrange( |
|
text_without_media_mask, "b i -> b 1 i 1" |
|
) |
|
attn_output = attn_output.masked_fill(text_without_media_mask, 0.0) |
|
|
|
out = rearrange(attn_output, "b h n d -> b n (h d)") |
|
return self.to_out(out) |
|
|
|
def _flash_attention_forward( |
|
self, query_states, key_states, value_states, cu_seqlens_q, cu_seqlens_k, dropout=0.0, softmax_scale=None |
|
): |
|
""" |
|
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token |
|
first unpad the input, then computes the attention scores and pad the final attention scores. |
|
|
|
Args: |
|
query_states (`torch.Tensor`): |
|
Input query states to be passed to Flash Attention API |
|
key_states (`torch.Tensor`): |
|
Input key states to be passed to Flash Attention API |
|
value_states (`torch.Tensor`): |
|
Input value states to be passed to Flash Attention API |
|
attention_mask (`torch.Tensor`): |
|
rename from cu_seqlens to keep compatability - (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths |
|
of the sequences in the batch. |
|
cu_seqlens_q (`torch.Tensor`): |
|
The length of each sequence in the query. |
|
To support data packing based cross-attention computation. |
|
cu_seqlens_k (`torch.Tensor`): |
|
The length of each sequence in the keys. |
|
To support data packing based cross-attention computation. |
|
dropout (`int`, *optional*): |
|
Attention dropout |
|
softmax_scale (`float`, *optional*): |
|
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) |
|
""" |
|
        # Packed inputs carry a dummy batch dimension of 1; drop it for the varlen kernel.
        assert query_states.size(0) == key_states.size(0) == value_states.size(0) == 1
|
query_states = query_states.squeeze(0) |
|
key_states = key_states.squeeze(0) |
|
value_states = value_states.squeeze(0) |
|
cu_seqlens_q = cu_seqlens_q.squeeze(0) |
|
cu_seqlens_k = cu_seqlens_k.squeeze(0) |
|
|
|
        with torch.no_grad():
            # Longest individual sequence on each side, required by the varlen kernel.
            max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item()
            max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item()
|
|
|
|
|
        # Cross-attention over packed sequences; no causal masking between text queries and media keys.
        attn_output = flash_attn_varlen_func(
|
q=query_states, |
|
k=key_states, |
|
v=value_states, |
|
cu_seqlens_q=cu_seqlens_q, |
|
cu_seqlens_k=cu_seqlens_k, |
|
max_seqlen_q=max_seqlen_q, |
|
max_seqlen_k=max_seqlen_k, |
|
dropout_p=dropout, |
|
softmax_scale=softmax_scale, |
|
causal=False, |
|
) |
|
|
|
|
return attn_output |
|
|
|
class InternLM2MLP(nn.Module): |
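    """SwiGLU-style feed-forward block (gate/up projections w1, w3 and down projection w2)."""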
|
def __init__(self, config, hidden_act='silu'): |
|
super().__init__() |
|
self.hidden_size = config.language_dim |
|
self.intermediate_size = config.intermediate_size |
|
self.w1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) |
|
self.w3 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) |
|
self.w2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) |
|
self.act_fn = ACT2FN[hidden_act] |
|
|
|
def forward(self, x): |
|
down_proj = self.w2(self.act_fn(self.w1(x)) * self.w3(x)) |
|
|
|
return down_proj |
|
|
|
class GatedCrossAttentionBlock(nn.Module): |
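    """Flamingo-style gated cross-attention block: the tanh gates are initialised at zero,
    so the block starts as an identity mapping and learns to mix in media features."""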
|
def __init__( |
|
self, |
|
config: MixinConfig |
|
): |
|
super().__init__() |
|
dim = config.language_dim |
|
intermediate_size = config.intermediate_size |
|
|
|
self.cross_attention_norm = InternLM2RMSNorm(dim, eps=1e-5) |
|
self.ffn_norm_2 = InternLM2RMSNorm(dim, eps=1e-5) |
|
|
|
self.cross_attn = CrossAttention( |
|
config=config |
|
) |
|
self.attn_gate = nn.Parameter(torch.tensor([0.0])) |
|
self.ffn_2 = InternLM2MLP(config) |
|
self.ff_gate = nn.Parameter(torch.tensor([0.0])) |
|
|
|
        # Cached media features and cross-attention position ids; expected to be set
        # externally before calling forward.
        self.media = None
        self.cross_attn_media_position_ids = None
        self.cross_attn_text_position_ids = None
|
|
|
def forward( |
|
self, |
|
x, |
|
media, |
|
use_cached_media=False, |
|
): |
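        """Apply gated cross-attention to `media`, then a gated feed-forward, each added to the residual stream."""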
|
residual = x |
|
x = self.cross_attention_norm(x) |
|
        # The same language-dim RMSNorm is applied to the media features, which assumes
        # they already have the language hidden size.
        media = self.cross_attention_norm(media)
|
x = ( |
|
self.cross_attn( |
|
x, |
|
media, |
|
use_cached_media=use_cached_media, |
|
media_position_ids=self.cross_attn_media_position_ids, |
|
text_position_ids=self.cross_attn_text_position_ids |
|
) |
|
* self.attn_gate.tanh() |
|
+ residual |
|
) |
|
|
|
residual = x |
|
x = self.ffn_norm_2(x) |
|
x = self.ffn_2(x) * self.ff_gate.tanh() + residual |
|
|
|
return x |
|
|