import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    # Newer timm releases expose layers under `timm.layers`; importing from the
    # legacy `timm.models.layers` path there emits a deprecation warning.
    from timm.layers import DropPath
except ImportError:  # older timm releases only provide the legacy path
    from timm.models.layers import DropPath


class BiMultiHeadAttention(nn.Module):
    """Multi-head cross-attention from visual queries to language keys/values.

    Despite the name, only the vision-side output is produced: queries are
    projected from ``v``, keys and values from ``l``, and the attended result
    is projected back to ``v_dim``.

    Args:
        v_dim: Feature size of the visual input ``v`` (input of ``v_proj``).
        l_dim: Feature size of the language input ``l`` (input of ``l_proj``).
        embed_dim: Total attention width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention probabilities.
        cfg: Not used by this module; accepted for interface compatibility.
    """

    def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1, cfg=None):
        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.v_dim = v_dim
        self.l_dim = l_dim

        assert (
            self.head_dim * self.num_heads == self.embed_dim
        ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
        # Standard 1/sqrt(d_head) scaling, applied to the queries up front.
        self.scale = self.head_dim ** (-0.5)
        self.dropout = dropout

        self.v_proj = nn.Linear(self.v_dim, self.embed_dim)
        self.l_proj = nn.Linear(self.l_dim, self.embed_dim)
        self.values_l_proj = nn.Linear(self.l_dim, self.embed_dim)
        self.out_v_proj = nn.Linear(self.embed_dim, self.v_dim)

        # Numerical-stability switches used in forward() before the softmax.
        self.stable_softmax_2d = True
        self.clamp_min_for_underflow = True
        self.clamp_max_for_overflow = True

        self._reset_parameters()

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
        """Split the last dimension into heads.

        Reshapes ``tensor`` to ``(bsz, seq_len, num_heads, head_dim)`` and
        returns a contiguous ``(bsz, num_heads, seq_len, head_dim)`` tensor.
        ``seq_len`` may be ``-1`` to let ``view`` infer it.
        """
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def _reset_parameters(self) -> None:
        """Xavier-uniform initialise every projection weight and zero its bias."""
        # Same order as the original per-layer calls, so RNG consumption matches.
        for proj in (self.v_proj, self.l_proj, self.values_l_proj, self.out_v_proj):
            nn.init.xavier_uniform_(proj.weight)
            nn.init.zeros_(proj.bias)

    def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
        """Attend from ``v`` to ``l`` and return the vision-side output.

        Args:
            v: Visual features; unpacked as ``(bsz, tgt_len, v_dim)``.
            l: Language features; expected ``(bsz, src_len, l_dim)``.
            attention_mask_v: Currently ignored.
            attention_mask_l: Currently ignored.

        Returns:
            Tensor of shape ``(bsz, tgt_len, v_dim)``.

        Raises:
            ValueError: If an intermediate attention tensor has an unexpected size.
        """
        # NOTE(review): neither attention mask is applied, so padded language
        # tokens still receive attention weight. Confirm whether callers rely
        # on masking before wiring it in (mask polarity is not visible here).
        bsz, tgt_len, _ = v.size()

        query_states = self.v_proj(v) * self.scale
        key_states = self._shape(self.l_proj(l), -1, bsz)
        value_l_states = self._shape(self.values_l_proj(l), -1, bsz)

        # Fold heads into the batch dimension so torch.bmm can be used.
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_l_states = value_l_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))  # bs*nhead, nimg, ntxt

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
            )

        if self.stable_softmax_2d:
            # Shift by the global max; softmax is invariant to a per-row
            # constant shift, and this keeps logits non-positive.
            attn_weights = attn_weights - attn_weights.max()

        if self.clamp_min_for_underflow:
            attn_weights = torch.clamp(
                attn_weights, min=-50000
            )  # Do not increase -50000, data type half has quite limited range
        if self.clamp_max_for_overflow:
            attn_weights = torch.clamp(
                attn_weights, max=50000
            )  # Do not increase 50000, data type half has quite limited range

        attn_weights_v = attn_weights.softmax(dim=-1)

        attn_probs_v = F.dropout(attn_weights_v, p=self.dropout, training=self.training)

        attn_output_v = torch.bmm(attn_probs_v, value_l_states)

        if attn_output_v.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output_v` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is {attn_output_v.size()}"
            )

        # Un-fold heads and merge them back into the embedding dimension.
        attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output_v = attn_output_v.transpose(1, 2)
        attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim)

        attn_output_v = self.out_v_proj(attn_output_v)

        return attn_output_v


# Pre-norm block wrapping BiMultiHeadAttention: text -> image attention only
# (the image features are updated from the text features).
class BiAttentionBlock(nn.Module):
    """Pre-LayerNorm wrapper around :class:`BiMultiHeadAttention` with DropPath.

    Args:
        v_dim: Feature size of the visual input.
        l_dim: Feature size of the language input.
        embed_dim: Attention width passed to the attention module.
        num_heads: Number of attention heads.
        dropout: Attention-probability dropout.
        drop_path: Stochastic-depth rate; ``0.0`` disables it (identity).
        cfg: Forwarded to the attention module, which does not use it.
    """

    def __init__(
        self,
        v_dim,
        l_dim,
        embed_dim,
        num_heads,
        dropout=0.1,
        drop_path=0.0,
        cfg=None,
    ):
        super().__init__()

        # pre layer norm
        self.layer_norm_v = nn.LayerNorm(v_dim)
        self.layer_norm_l = nn.LayerNorm(l_dim)
        self.attn = BiMultiHeadAttention(
            v_dim=v_dim,
            l_dim=l_dim,
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            cfg=cfg,
        )

        # Stochastic depth on the attention output (no layer scale is applied).
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
        """Normalise both inputs, attend from ``v`` to ``l``, apply DropPath.

        Returns only the update for ``v`` (shape ``(bsz, tgt_len, v_dim)``);
        no residual is added here. Presumably the caller adds it; verify
        against call sites. The masks are forwarded but are ignored by the
        attention module.
        """
        v = self.layer_norm_v(v)
        l = self.layer_norm_l(l)
        delta_v = self.attn(
            v, l, attention_mask_v=attention_mask_v, attention_mask_l=attention_mask_l
        )
        delta_v = self.drop_path(delta_v)
        return delta_v