import torch
import torch.nn as nn
import torch.nn.functional as F
device = 'cpu'

class Attention_Layer(nn.Module):
    """Luong-style dot-product attention over the encoder states."""

    def __init__(self):
        super().__init__()

    def forward(self, enc_out, dec_out, mask):
        """
        enc_out: (batch_size, seq_len, enc_dim)
        dec_out: (batch_size, dec_dim)
        mask:    (batch_size, seq_len), 1 for real tokens, 0 for padding
        """
        # Attention scores: dot product of the decoder state with every
        # encoder state, giving (batch_size, seq_len).
        a_t = torch.matmul(enc_out, dec_out.unsqueeze(-1)).squeeze(-1)

        # Additionally mask the last real token (the position just before the
        # first padding slot, presumably <eos>) so attention covers content
        # tokens only. argmin finds the first 0 in each row of the mask.
        mask = mask.clone()
        length = mask.argmin(dim=-1) - 1
        mask[torch.arange(mask.shape[0]), length] = 0
        # Fill masked positions (mask == 0) with a large negative score so the
        # softmax assigns them ~zero probability.
        a_t = a_t.masked_fill(mask == 0, -1e10)

        attn_probs = F.softmax(a_t, dim=-1)

        # Context vector: probability-weighted sum of the encoder states,
        # (batch_size, 1, seq_len) @ (batch_size, seq_len, enc_dim).
        c_t = torch.matmul(attn_probs.unsqueeze(1), enc_out).squeeze(1)

        return c_t, attn_probs
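
# A quick shape check for the attention layer (a sketch; the tensors below
# are hypothetical stand-ins):
#
#   attn = Attention_Layer()
#   enc_out = torch.randn(4, 10, 256)       # (batch, seq_len, enc_dim)
#   dec_out = torch.randn(4, 256)           # (batch, dec_dim)
#   mask = torch.ones(4, 10)
#   mask[:, 7:] = 0                         # positions 7-9 are padding
#   c_t, p = attn(enc_out, dec_out, mask)   # c_t: (4, 256), p: (4, 10)
#
# The dot-product score requires enc_dim == dec_dim.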

class Seq2Seq(nn.Module):
    def __init__(self, input_vocab_size, output_vocab_size, emb_dim, hid_size):
        super().__init__()
        self.hid_size = hid_size
        self.input_embedding = nn.Embedding(input_vocab_size, emb_dim)
        self.output_embedding = nn.Embedding(output_vocab_size, emb_dim)
        self.dropout = nn.Dropout(0.2)
        self.encoder = nn.LSTM(emb_dim, hid_size, batch_first=True)
        # The decoder cell consumes the token embedding concatenated with the
        # previous attentional state, hence emb_dim + hid_size inputs.
        self.decoder = nn.LSTMCell(emb_dim + hid_size, hid_size)
        self.attn = Attention_Layer()
        self.tanh = nn.Tanh()
        self.attentional_layer = nn.Linear(2 * hid_size, hid_size)
        self.output_layer = nn.Linear(hid_size, output_vocab_size)

        # Set per batch before decoding (see forward and translate).
        self.all_encoder_states = None
        self.input_mask = None

    def encode(self, x):
        x = self.input_embedding(x)
        x = self.dropout(x)
        all_encoder_states, (last_hidden_state, last_cell_state) = self.encoder(x)

        # The encoder is a single-layer LSTM, so drop the num_layers
        # dimension: (1, batch, hid) -> (batch, hid).
        last_hidden_state = last_hidden_state.squeeze(0)
        last_cell_state = last_cell_state.squeeze(0)

        all_encoder_states = self.dropout(all_encoder_states)
        last_hidden_state = self.dropout(last_hidden_state)

        self.all_encoder_states = all_encoder_states

        # No attentional state exists before the first decoder step, so
        # start from zeros.
        first_attentional_state = torch.zeros(
            (x.shape[0], self.hid_size), device=device
        )

        return (last_hidden_state, last_cell_state, first_attentional_state)
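
    # decode_step below follows the input-feeding scheme of Luong et al.
    # (2015), "Effective Approaches to Attention-based Neural Machine
    # Translation": each step consumes the previous attentional state
    # alongside the token embedding, runs one LSTMCell step, attends over
    # the encoder states, and produces a new attentional state
    # h~_t = tanh(W_c [c_t; h_t]) from which the output logits are read.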
    def decode_step(self, x, prev_output):
        prev_hidden_state, prev_cell_state, prev_attentional_state = prev_output

        # Input feeding: concatenate the previous attentional state with the
        # current token embedding.
        x = self.output_embedding(x)
        x = torch.cat([x, prev_attentional_state], dim=-1)
        x = self.dropout(x)

        new_hidden_state, new_cell_state = self.decoder(
            x, (prev_hidden_state, prev_cell_state)
        )
        new_hidden_state = self.dropout(new_hidden_state)

        c_t, attn_probs = self.attn(
            self.all_encoder_states, new_hidden_state, self.input_mask
        )

        # Attentional hidden state: tanh(W_c [c_t; h_t]).
        new_attentional_hid_state = self.tanh(
            self.attentional_layer(torch.cat([c_t, new_hidden_state], dim=-1))
        )

        output_logits = self.output_layer(new_attentional_hid_state)

        return (
            (new_hidden_state, new_cell_state, new_attentional_hid_state),
            output_logits,
            attn_probs,
        )

    def decode(self, x, encoder_output):
        decoder_logits = []
        state = encoder_output

        # Teacher forcing: feed the ground-truth token x[:, i] at each step
        # and predict x[:, i + 1]; the final position is never fed as input,
        # hence seq_len - 1 steps.
        for i in range(x.shape[1] - 1):
            state, output_logits, _ = self.decode_step(x[:, i], state)
            decoder_logits.append(output_logits)

        # (batch_size, seq_len - 1, output_vocab_size)
        return torch.stack(decoder_logits, dim=1)

    def forward(self, input_tokens, output_tokens, mask):
        # Stash the source-side mask for the attention layer.
        self.input_mask = mask

        encoder_output = self.encode(input_tokens)
        decoder_output = self.decode(output_tokens, encoder_output)

        return decoder_output
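
    # A minimal training-loss sketch (hypothetical names; assumes the usual
    # convention that output_tokens[:, 1:] holds the targets for the logits
    # returned above, with padding ignored via a pad_id):
    #
    #   logits = model(input_tokens, output_tokens, input_mask)
    #   loss = F.cross_entropy(
    #       logits.reshape(-1, logits.shape[-1]),
    #       output_tokens[:, 1:].reshape(-1),
    #       ignore_index=pad_id,
    #   )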

    def translate(self, input_tokens, input_mask, max_len):
        """
        Greedy decoding.

        input_tokens: (batch_size, seq_len)
        input_mask:   (batch_size, seq_len), 1 for real tokens, 0 for padding
        max_len: maximum length of the output sequence
        """
        # Disables dropout; callers should restore .train() to keep training.
        self.eval()
        self.input_mask = input_mask

        # Start every sequence from token id 0, assumed to be <sos>.
        decoder_input = torch.zeros(
            input_tokens.shape[0], device=device, dtype=torch.long
        )
        decoder_output = []
        attn_weights = []

        with torch.no_grad():
            encoder_output = self.encode(input_tokens)
            state = encoder_output

            for _ in range(max_len):
                state, output_logits, attn_probs = self.decode_step(
                    decoder_input, state
                )
                # Greedy choice: feed the most likely token back in.
                predicted = output_logits.argmax(dim=-1)
                decoder_output.append(predicted)
                decoder_input = predicted
                attn_weights.append(attn_probs)

        # (batch_size, max_len) token ids and (batch_size, max_len, seq_len)
        # attention weights.
        return torch.stack(decoder_output, dim=-1), torch.stack(attn_weights, dim=1)
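
# Minimal smoke test for the model above (a sketch: vocab sizes, shapes, and
# the 1 = real / 0 = padding mask convention are illustrative assumptions).
if __name__ == "__main__":
    torch.manual_seed(0)
    model = Seq2Seq(input_vocab_size=50, output_vocab_size=60, emb_dim=32, hid_size=64)

    input_tokens = torch.randint(1, 50, (4, 10))
    output_tokens = torch.randint(1, 60, (4, 12))
    input_mask = torch.ones(4, 10)
    input_mask[:, 7:] = 0  # last three source positions are padding

    logits = model(input_tokens, output_tokens, input_mask)
    print(logits.shape)  # (4, 11, 60): one prediction per target position

    tokens, attn = model.translate(input_tokens, input_mask, max_len=5)
    print(tokens.shape, attn.shape)  # (4, 5) and (4, 5, 10)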