import torch
import torch.nn as nn

# NOTE(review): Encoder and Decoder are not defined in this file; they are
# presumably supplied by this wildcard import - confirm and consider
# importing them by name.
from transformer import *


class Transformer(nn.Module):
    """Encoder-decoder model that builds its own attention masks.

    The module wires an ``Encoder`` and a ``Decoder`` together and, on every
    forward pass, derives the padding and "no peek" masks from the raw token
    id tensors before handing them to the two sub-modules.
    """

    def __init__(self, src_pad_idx, trg_pad_idx, enc_voc_size, dec_voc_size,
                 d_model, n_head, max_len, ffn_hidden, n_layers, drop_prob,
                 learnable_pos_emb=True):
        """Build the encoder and decoder stacks.

        Args:
            src_pad_idx: Token id treated as padding in source sequences.
            trg_pad_idx: Token id treated as padding in target sequences.
            enc_voc_size: Forwarded to ``Encoder`` as ``enc_voc_size``.
            dec_voc_size: Forwarded to ``Decoder`` as ``dec_voc_size``.
            d_model: Forwarded to both sub-modules as ``d_model``.
            n_head: Forwarded to both sub-modules as ``n_head``.
            max_len: Forwarded to both sub-modules as ``max_len``.
            ffn_hidden: Forwarded to both sub-modules as ``ffn_hidden``.
            n_layers: Forwarded to both sub-modules as ``n_layers``.
            drop_prob: Forwarded to both sub-modules as ``drop_prob``.
            learnable_pos_emb: Forwarded to both sub-modules as
                ``learnable_pos_emb``.
        """
        super().__init__()
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.encoder = Encoder(d_model=d_model,
                               n_head=n_head,
                               max_len=max_len,
                               ffn_hidden=ffn_hidden,
                               enc_voc_size=enc_voc_size,
                               drop_prob=drop_prob,
                               n_layers=n_layers,
                               padding_idx=src_pad_idx,
                               learnable_pos_emb=learnable_pos_emb)
        self.decoder = Decoder(d_model=d_model,
                               n_head=n_head,
                               max_len=max_len,
                               ffn_hidden=ffn_hidden,
                               dec_voc_size=dec_voc_size,
                               drop_prob=drop_prob,
                               n_layers=n_layers,
                               padding_idx=trg_pad_idx,
                               learnable_pos_emb=learnable_pos_emb)

    def get_device(self) -> torch.device:
        """Return the device of the first registered parameter."""
        return next(self.parameters()).device

    def forward(self, src, trg):
        """Run the encoder on ``src`` and the decoder on ``trg``.

        Args:
            src: Source token ids; dimension 1 is used as the sequence
                length (batch_size x src_len per the mask helpers).
            trg: Target token ids, laid out like ``src``.

        Returns:
            Whatever ``self.decoder`` returns for the given inputs.
        """
        device = self.get_device()

        # Self-attention mask over the source (queries and keys are both src).
        src_mask = self.make_pad_mask(
            src, src, self.src_pad_idx, self.src_pad_idx).to(device)
        # Cross-attention mask: queries come from trg, keys from src.
        src_trg_mask = self.make_pad_mask(
            trg, src, self.trg_pad_idx, self.src_pad_idx).to(device)

        # Target self-attention mask: padding mask AND lower-triangular mask.
        # The (len x len) triangle broadcasts over the batch and head dims.
        trg_pad_mask = self.make_pad_mask(
            trg, trg, self.trg_pad_idx, self.trg_pad_idx).to(device)
        no_peak_mask = self.make_no_peak_mask(trg, trg).to(device)
        trg_mask = trg_pad_mask & no_peak_mask

        enc_src = self.encoder(src, src_mask)
        output = self.decoder(trg, enc_src, trg_mask, src_trg_mask)
        return output

    def make_pad_mask(self, q, k, q_pad_idx, k_pad_idx):
        """Build a boolean mask that is True where neither side is padding.

        Args:
            q: Query-side token ids, batch_size x len_q.
            k: Key-side token ids, batch_size x len_k.
            q_pad_idx: Padding id to exclude on the query side.
            k_pad_idx: Padding id to exclude on the key side.

        Returns:
            Bool tensor of shape batch_size x 1 x len_q x len_k whose entry
            ``[b, 0, i, j]`` is True iff ``q[b, i] != q_pad_idx`` and
            ``k[b, j] != k_pad_idx``.
        """
        # batch_size x 1 x 1 x len_k
        k_valid = k.ne(k_pad_idx).unsqueeze(1).unsqueeze(2)
        # batch_size x 1 x len_q x 1
        q_valid = q.ne(q_pad_idx).unsqueeze(1).unsqueeze(3)
        # Broadcasting expands both operands to batch_size x 1 x len_q x len_k,
        # so no explicit repeat() copies are needed before the AND.
        return k_valid & q_valid

    def make_no_peak_mask(self, q, k):
        """Build a lower-triangular boolean mask of shape len_q x len_k.

        Entry ``[i, j]`` is True iff ``j <= i``. Only the sizes of ``q`` and
        ``k`` along dimension 1 (and the device of ``q``) are used.

        Returns:
            Bool tensor of shape len_q x len_k, created on ``q``'s device.
        """
        len_q, len_k = q.size(1), k.size(1)
        # Build directly on the input's device to avoid a host-to-device copy
        # on every forward pass; .bool() replaces the legacy BoolTensor cast.
        return torch.tril(torch.ones(len_q, len_k, device=q.device)).bool()