|
|
|
import torch
import torch.nn as nn

from transformer import *
|
|
class Transformer(nn.Module):

    def __init__(self, src_pad_idx, trg_pad_idx, enc_voc_size, dec_voc_size, d_model, n_head, max_len,
                 ffn_hidden, n_layers, drop_prob, learnable_pos_emb=True):
        super().__init__()
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.encoder = Encoder(d_model=d_model,
                               n_head=n_head,
                               max_len=max_len,
                               ffn_hidden=ffn_hidden,
                               enc_voc_size=enc_voc_size,
                               drop_prob=drop_prob,
                               n_layers=n_layers,
                               padding_idx=src_pad_idx,
                               learnable_pos_emb=learnable_pos_emb)

        self.decoder = Decoder(d_model=d_model,
                               n_head=n_head,
                               max_len=max_len,
                               ffn_hidden=ffn_hidden,
                               dec_voc_size=dec_voc_size,
                               drop_prob=drop_prob,
                               n_layers=n_layers,
                               padding_idx=trg_pad_idx,
                               learnable_pos_emb=learnable_pos_emb)
|
    def get_device(self):
        # all parameters live on the same device, so any one of them tells us where the model is
        return next(self.parameters()).device

    def forward(self, src, trg):
        device = self.get_device()
        # encoder self-attention mask: hide source padding tokens
        src_mask = self.make_pad_mask(src, src, self.src_pad_idx, self.src_pad_idx).to(device)
        # cross-attention mask: queries come from the target, keys from the source
        src_trg_mask = self.make_pad_mask(trg, src, self.trg_pad_idx, self.src_pad_idx).to(device)
        # decoder self-attention mask: hide target padding and future positions
        trg_mask = self.make_pad_mask(trg, trg, self.trg_pad_idx, self.trg_pad_idx).to(device) * \
                   self.make_no_peak_mask(trg, trg).to(device)
        enc_src = self.encoder(src, src_mask)
        output = self.decoder(trg, enc_src, trg_mask, src_trg_mask)
        return output
|
    def make_pad_mask(self, q, k, q_pad_idx, k_pad_idx):
        len_q, len_k = q.size(1), k.size(1)

        # (batch_size, 1, 1, len_k): True where the key token is not padding
        k = k.ne(k_pad_idx).unsqueeze(1).unsqueeze(2)
        # (batch_size, 1, len_q, len_k)
        k = k.repeat(1, 1, len_q, 1)

        # (batch_size, 1, len_q, 1): True where the query token is not padding
        q = q.ne(q_pad_idx).unsqueeze(1).unsqueeze(3)
        # (batch_size, 1, len_q, len_k)
        q = q.repeat(1, 1, 1, len_k)

        # attention is allowed only where both the query and the key token are real (non-pad)
        mask = k & q
        return mask
|
    def make_no_peak_mask(self, q, k):
        len_q, len_k = q.size(1), k.size(1)

        # (len_q, len_k) lower-triangular (causal) mask: position i may only attend to positions <= i
        mask = torch.tril(torch.ones(len_q, len_k)).type(torch.BoolTensor)
        return mask
|
|
|
|
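# Minimal usage sketch. The hyperparameter values and tensor shapes below are illustrative
# assumptions, not values taken from the original training configuration.
if __name__ == "__main__":
    PAD_IDX = 0
    model = Transformer(src_pad_idx=PAD_IDX,
                        trg_pad_idx=PAD_IDX,
                        enc_voc_size=32000,
                        dec_voc_size=32000,
                        d_model=512,
                        n_head=8,
                        max_len=256,
                        ffn_hidden=2048,
                        n_layers=6,
                        drop_prob=0.1)

    # dummy token-id batches of shape (batch_size, seq_len); index PAD_IDX marks padding
    src = torch.randint(1, 32000, (2, 10))
    trg = torch.randint(1, 32000, (2, 12))

    out = model(src, trg)
    # assuming the decoder projects to vocabulary logits, out has shape (2, 12, dec_voc_size)
    print(out.shape)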