import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from layers.fc import MLP, FC
from layers.layer_norm import LayerNorm


# ------------------------------------
# ---------- Masking sequence --------
# ------------------------------------
def make_mask(feature):
    # Flag padded positions (all-zero feature vectors) so attention can ignore them.
    # (batch, seq_len, feat_dim) -> boolean mask of shape (batch, 1, 1, seq_len)
    return (torch.sum(torch.abs(feature), dim=-1) == 0).unsqueeze(1).unsqueeze(2)
# ------------------------------
# ---------- Flattening --------
# ------------------------------
class AttFlat(nn.Module):
    # Attention-based flattening: learns `flat_glimpse` attention distributions
    # over the sequence and uses them to pool the features.
    def __init__(self, args, flat_glimpse, merge=False):
        super(AttFlat, self).__init__()
        self.args = args
        self.merge = merge
        self.flat_glimpse = flat_glimpse
        self.mlp = MLP(
            in_size=args.hidden_size,
            mid_size=args.ff_size,
            out_size=flat_glimpse,
            dropout_r=args.dropout_r,
            use_relu=True,
        )

        if self.merge:
            self.linear_merge = nn.Linear(
                args.hidden_size * flat_glimpse, args.hidden_size * 2
            )

    def forward(self, x, x_mask):
        att = self.mlp(x)
        if x_mask is not None:
            att = att.masked_fill(x_mask.squeeze(1).squeeze(1).unsqueeze(2), -1e9)
        att = F.softmax(att, dim=1)

        att_list = []
        for i in range(self.flat_glimpse):
            att_list.append(torch.sum(att[:, :, i : i + 1] * x, dim=1))

        if self.merge:
            x_atted = torch.cat(att_list, dim=1)
            x_atted = self.linear_merge(x_atted)
            return x_atted

        return torch.stack(att_list).transpose(0, 1)
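
# Shape sketch (hidden_size=512 is an assumed example value): with flat_glimpse=1
# and merge=True, x of shape (batch, seq_len, 512) is pooled to (batch, 1024);
# with merge=False the output is (batch, flat_glimpse, 512).
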
# ------------------------
# ---- Self Attention ----
# ------------------------
class SA(nn.Module):
    # Standard transformer encoder layer: multi-head self-attention followed by a
    # feed-forward network, each with residual connection, dropout and LayerNorm.
    def __init__(self, args):
        super(SA, self).__init__()
        self.mhatt = MHAtt(args)
        self.ffn = FFN(args)
        self.dropout1 = nn.Dropout(args.dropout_r)
        self.norm1 = LayerNorm(args.hidden_size)
        self.dropout2 = nn.Dropout(args.dropout_r)
        self.norm2 = LayerNorm(args.hidden_size)

    def forward(self, y, y_mask):
        y = self.norm1(y + self.dropout1(self.mhatt(y, y, y, y_mask)))
        y = self.norm2(y + self.dropout2(self.ffn(y)))
        return y
# -------------------------------
# ---- Self Guided Attention ----
# -------------------------------
class SGA(nn.Module):
    # Self-attention over x, followed by guided (cross-modal) attention where x
    # queries y, then a feed-forward network; each sub-layer has residual
    # connection, dropout and LayerNorm.
    def __init__(self, args):
        super(SGA, self).__init__()
        self.mhatt1 = MHAtt(args)
        self.mhatt2 = MHAtt(args)
        self.ffn = FFN(args)
        self.dropout1 = nn.Dropout(args.dropout_r)
        self.norm1 = LayerNorm(args.hidden_size)
        self.dropout2 = nn.Dropout(args.dropout_r)
        self.norm2 = LayerNorm(args.hidden_size)
        self.dropout3 = nn.Dropout(args.dropout_r)
        self.norm3 = LayerNorm(args.hidden_size)

    def forward(self, x, y, x_mask, y_mask):
        x = self.norm1(x + self.dropout1(self.mhatt1(v=x, k=x, q=x, mask=x_mask)))
        x = self.norm2(x + self.dropout2(self.mhatt2(v=y, k=y, q=x, mask=y_mask)))
        x = self.norm3(x + self.dropout3(self.ffn(x)))
        return x
# ------------------------------
# ---- Multi-Head Attention ----
# ------------------------------
class MHAtt(nn.Module):
    def __init__(self, args):
        super(MHAtt, self).__init__()
        self.args = args
        self.linear_v = nn.Linear(args.hidden_size, args.hidden_size)
        self.linear_k = nn.Linear(args.hidden_size, args.hidden_size)
        self.linear_q = nn.Linear(args.hidden_size, args.hidden_size)
        self.linear_merge = nn.Linear(args.hidden_size, args.hidden_size)
        self.dropout = nn.Dropout(args.dropout_r)

    def forward(self, v, k, q, mask):
        n_batches = q.size(0)
        head_dim = self.args.hidden_size // self.args.multi_head

        # Project and split into heads: (batch, heads, seq_len, head_dim)
        v = (
            self.linear_v(v)
            .view(n_batches, -1, self.args.multi_head, head_dim)
            .transpose(1, 2)
        )
        k = (
            self.linear_k(k)
            .view(n_batches, -1, self.args.multi_head, head_dim)
            .transpose(1, 2)
        )
        q = (
            self.linear_q(q)
            .view(n_batches, -1, self.args.multi_head, head_dim)
            .transpose(1, 2)
        )

        atted = self.att(v, k, q, mask)

        # Merge heads back to (batch, seq_len, hidden_size)
        atted = (
            atted.transpose(1, 2)
            .contiguous()
            .view(n_batches, -1, self.args.hidden_size)
        )
        atted = self.linear_merge(atted)
        return atted

    def att(self, value, key, query, mask):
        # Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask, -1e9)
        att_map = F.softmax(scores, dim=-1)
        att_map = self.dropout(att_map)
        return torch.matmul(att_map, value)
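
# Example shapes (assuming hidden_size=512, multi_head=8): q, k and v of shape
# (batch, seq_len, 512) are split into 8 heads of width 64, attended over, and
# merged back to (batch, seq_len, 512).
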
# ---------------------------
# ---- Feed Forward Nets ----
# ---------------------------
class FFN(nn.Module):
    def __init__(self, args):
        super(FFN, self).__init__()
        self.mlp = MLP(
            in_size=args.hidden_size,
            mid_size=args.ff_size,
            out_size=args.hidden_size,
            dropout_r=args.dropout_r,
            use_relu=True,
        )

    def forward(self, x):
        return self.mlp(x)
# ---------------------------
# ---- FF + norm ------------
# ---------------------------
class FFAndNorm(nn.Module):
    def __init__(self, args):
        super(FFAndNorm, self).__init__()
        self.ffn = FFN(args)
        self.norm1 = LayerNorm(args.hidden_size)
        self.dropout2 = nn.Dropout(args.dropout_r)
        self.norm2 = LayerNorm(args.hidden_size)

    def forward(self, x):
        x = self.norm1(x)
        x = self.norm2(x + self.dropout2(self.ffn(x)))
        return x
class Block(nn.Module):
    # One encoder block over the two modalities: SA refines the language stream,
    # SGA refines the audio stream guided by language.
    def __init__(self, args, i):
        super(Block, self).__init__()
        self.args = args
        self.sa1 = SA(args)
        self.sa3 = SGA(args)

        self.last = i == args.layer - 1
        if not self.last:
            self.att_lang = AttFlat(args, args.lang_seq_len, merge=False)
            self.att_audio = AttFlat(args, args.audio_seq_len, merge=False)
            self.norm_l = LayerNorm(args.hidden_size)
            self.norm_i = LayerNorm(args.hidden_size)
            self.dropout = nn.Dropout(args.dropout_r)

    def forward(self, x, x_mask, y, y_mask):
        ax = self.sa1(x, x_mask)
        ay = self.sa3(y, x, y_mask, x_mask)

        x = ax + x
        y = ay + y

        if self.last:
            return x, y

        ax = self.att_lang(x, x_mask)
        ay = self.att_audio(y, y_mask)

        return self.norm_l(x + self.dropout(ax)), self.norm_i(y + self.dropout(ay))
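
# Per block: x (language) is refined by self-attention; y (audio) is refined by
# self-attention plus cross-attention with x as keys/values. On every block but
# the last, attention-flattened summaries are added back as residuals and normalised.
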
class Model_LA(nn.Module):
    def __init__(self, args, vocab_size, pretrained_emb):
        super(Model_LA, self).__init__()
        self.args = args

        # Word embedding + LSTM for the language stream
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=args.word_embed_size
        )

        # Loading the GloVe embedding weights
        self.embedding.weight.data.copy_(torch.from_numpy(pretrained_emb))

        self.lstm_x = nn.LSTM(
            input_size=args.word_embed_size,
            hidden_size=args.hidden_size,
            num_layers=1,
            batch_first=True,
        )

        # self.lstm_y = nn.LSTM(
        #     input_size=args.audio_feat_size,
        #     hidden_size=args.hidden_size,
        #     num_layers=1,
        #     batch_first=True
        # )

        # Audio feature size to hidden size
        self.adapter = nn.Linear(args.audio_feat_size, args.hidden_size)

        # Encoder blocks
        self.enc_list = nn.ModuleList([Block(args, i) for i in range(args.layer)])

        # Flattening features before projection
        self.attflat_img = AttFlat(args, 1, merge=True)
        self.attflat_lang = AttFlat(args, 1, merge=True)

        # Classification layers
        self.proj_norm = LayerNorm(2 * args.hidden_size)
        self.proj = nn.Linear(2 * args.hidden_size, args.ans_size)

    def forward(self, x, y, _):
        x_mask = make_mask(x.unsqueeze(2))
        y_mask = make_mask(y)

        embedding = self.embedding(x)
        x, _ = self.lstm_x(embedding)
        # y, _ = self.lstm_y(y)
        y = self.adapter(y)

        # Only the first block receives the padding masks.
        for i, block in enumerate(self.enc_list):
            x_m, y_m = None, None
            if i == 0:
                x_m, y_m = x_mask, y_mask
            x, y = block(x, x_m, y, y_m)

        x = self.attflat_lang(x, None)
        y = self.attflat_img(y, None)

        # Classification layers
        proj_feat = x + y
        proj_feat = self.proj_norm(proj_feat)
        ans = self.proj(proj_feat)

        return ans
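
# Minimal smoke-test sketch (not part of the original model code). The
# hyper-parameter values below are assumptions chosen only to exercise tensor
# shapes, and it assumes the repo's `layers` package is importable.
if __name__ == "__main__":
    from types import SimpleNamespace

    import numpy as np

    args = SimpleNamespace(
        hidden_size=512, ff_size=2048, dropout_r=0.1, multi_head=8, layer=4,
        lang_seq_len=60, audio_seq_len=60, word_embed_size=300,
        audio_feat_size=80, ans_size=7,
    )
    vocab_size = 100
    pretrained_emb = np.random.randn(vocab_size, args.word_embed_size).astype(np.float32)

    model = Model_LA(args, vocab_size, pretrained_emb)
    tokens = torch.randint(0, vocab_size, (2, args.lang_seq_len))     # word indices
    audio = torch.randn(2, args.audio_seq_len, args.audio_feat_size)  # audio features
    print(model(tokens, audio, None).shape)  # expected: torch.Size([2, 7])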