|
""" |
|
@author : Hyunwoong |
|
@when : 2019-12-18 |
|
@homepage : https://github.com/gusdnd852 |
|
""" |
|
|
|
import math |
|
import torch |
|
import torch.nn as nn |
|
|
|
|
|
class EncoderLayer(nn.Module): |
|
|
|
def __init__(self, d_model, ffn_hidden, n_head, drop_prob): |
|
super(EncoderLayer, self).__init__() |
|
self.attention = MultiHeadAttention(d_model=d_model, n_head=n_head) |
|
self.norm1 = LayerNorm(d_model=d_model) |
|
self.dropout1 = nn.Dropout(p=drop_prob) |
|
|
|
self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob) |
|
self.norm2 = LayerNorm(d_model=d_model) |
|
self.dropout2 = nn.Dropout(p=drop_prob) |
|
|
|
    def forward(self, x, s_mask):
        # 1. self-attention over the source sequence
        _x = x
        x = self.attention(q=x, k=x, v=x, mask=s_mask)

        # 2. add & norm
        x = self.dropout1(x)
        x = self.norm1(x + _x)

        # 3. position-wise feed forward
        _x = x
        x = self.ffn(x)

        # 4. add & norm
        x = self.dropout2(x)
        x = self.norm2(x + _x)
        return x
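

# Illustrative sketch (not part of the original model): a single EncoderLayer preserves the
# [batch_size, length, d_model] shape. The sizes below are arbitrary demo values and the
# helper name is hypothetical; call it manually if you want to verify the shapes.
def _demo_encoder_layer_shape():
    layer = EncoderLayer(d_model=8, ffn_hidden=32, n_head=2, drop_prob=0.0)
    x = torch.randn(2, 5, 8)                          # [batch_size, length, d_model]
    mask = torch.ones(2, 1, 5, 5, dtype=torch.bool)   # nothing masked out
    assert layer(x, mask).shape == (2, 5, 8)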
|
|
|
|
|
class DecoderLayer(nn.Module): |
|
|
|
def __init__(self, d_model, ffn_hidden, n_head, drop_prob): |
|
super(DecoderLayer, self).__init__() |
|
self.self_attention = MultiHeadAttention(d_model=d_model, n_head=n_head) |
|
self.norm1 = LayerNorm(d_model=d_model) |
|
self.dropout1 = nn.Dropout(p=drop_prob) |
|
|
|
self.enc_dec_attention = MultiHeadAttention(d_model=d_model, n_head=n_head) |
|
self.norm2 = LayerNorm(d_model=d_model) |
|
self.dropout2 = nn.Dropout(p=drop_prob) |
|
|
|
self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob) |
|
self.norm3 = LayerNorm(d_model=d_model) |
|
self.dropout3 = nn.Dropout(p=drop_prob) |
|
|
|
    def forward(self, dec, enc, t_mask, s_mask):
        # 1. masked self-attention over the decoder input
        _x = dec
        x = self.self_attention(q=dec, k=dec, v=dec, mask=t_mask)

        # 2. add & norm
        x = self.dropout1(x)
        x = self.norm1(x + _x)

        if enc is not None:
            # 3. encoder-decoder (cross) attention
            _x = x
            x = self.enc_dec_attention(q=x, k=enc, v=enc, mask=s_mask)

            # 4. add & norm
            x = self.dropout2(x)
            x = self.norm2(x + _x)

        # 5. position-wise feed forward
        _x = x
        x = self.ffn(x)

        # 6. add & norm
        x = self.dropout3(x)
        x = self.norm3(x + _x)
        return x
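

# Illustrative sketch (not part of the original model): a DecoderLayer consumes the decoder
# states plus the encoder output and keeps the decoder's [batch_size, length, d_model] shape.
# The sizes below are arbitrary demo values and the helper name is hypothetical.
def _demo_decoder_layer_shape():
    layer = DecoderLayer(d_model=8, ffn_hidden=32, n_head=2, drop_prob=0.0)
    dec = torch.randn(2, 4, 8)                         # decoder states
    enc = torch.randn(2, 6, 8)                         # encoder output (a different length is fine)
    t_mask = torch.tril(torch.ones(4, 4)).bool()       # causal mask for decoder self-attention
    s_mask = torch.ones(2, 1, 4, 6, dtype=torch.bool)  # no source positions masked
    assert layer(dec, enc, t_mask, s_mask).shape == (2, 4, 8)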
|
|
|
|
|
class ScaleDotProductAttention(nn.Module): |
|
""" |
|
compute scale dot product attention |
|
|
|
Query : given sentence that we focused on (decoder) |
|
Key : every sentence to check relationship with Qeury(encoder) |
|
Value : every sentence same with Key (encoder) |
|
""" |
|
|
|
def __init__(self): |
|
super(ScaleDotProductAttention, self).__init__() |
|
self.softmax = nn.Softmax(dim=-1) |
|
|
|
    def forward(self, q, k, v, mask=None):
        # input is a 4-dimensional tensor: [batch_size, head, length, d_tensor]
        batch_size, head, length, d_tensor = k.size()

        # 1. dot product Query with Key^T to compute similarity scores
        k_t = k.transpose(2, 3)
        score = (q @ k_t) / math.sqrt(d_tensor)

        # 2. optional masking: masked positions receive a large negative score
        if mask is not None:
            score = score.masked_fill(mask == 0, -10000)

        # 3. softmax over the key dimension turns the scores into attention weights
        score = self.softmax(score)

        # 4. weight the Values by the attention weights
        v = score @ v

        return v, score
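

# Illustrative sketch (not part of the original model): checks that ScaleDotProductAttention
# matches a manual softmax(Q K^T / sqrt(d)) V computation on random tensors. The function name
# and tensor sizes are arbitrary choices for the demo.
def _demo_scale_dot_product_attention():
    attn = ScaleDotProductAttention()
    q = torch.randn(2, 4, 5, 8)   # [batch_size, head, length, d_tensor]
    k = torch.randn(2, 4, 5, 8)
    v = torch.randn(2, 4, 5, 8)
    out, score = attn(q, k, v)
    manual = torch.softmax((q @ k.transpose(2, 3)) / math.sqrt(8), dim=-1) @ v
    assert out.shape == (2, 4, 5, 8)
    assert torch.allclose(out, manual, atol=1e-6)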
|
|
|
|
|
class PositionwiseFeedForward(nn.Module): |
|
|
|
def __init__(self, d_model, hidden, drop_prob=0.1): |
|
super(PositionwiseFeedForward, self).__init__() |
|
self.linear1 = nn.Linear(d_model, hidden) |
|
self.linear2 = nn.Linear(hidden, d_model) |
|
self.relu = nn.ReLU() |
|
self.dropout = nn.Dropout(p=drop_prob) |
|
|
|
def forward(self, x): |
|
x = self.linear1(x) |
|
x = self.relu(x) |
|
x = self.dropout(x) |
|
x = self.linear2(x) |
|
return x |
|
|
|
|
|
class MultiHeadAttention(nn.Module): |
|
|
|
def __init__(self, d_model, n_head): |
|
super(MultiHeadAttention, self).__init__() |
|
self.n_head = n_head |
|
self.attention = ScaleDotProductAttention() |
|
self.w_q = nn.Linear(d_model, d_model, bias=False) |
|
self.w_k = nn.Linear(d_model, d_model, bias=False) |
|
self.w_v = nn.Linear(d_model, d_model, bias=False) |
|
self.w_concat = nn.Linear(d_model, d_model, bias=False) |
|
|
|
    def forward(self, q, k, v, mask=None):
        # 1. project with the learned weight matrices
        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)

        # 2. split the projections by number of heads
        q, k, v = self.split(q), self.split(k), self.split(v)

        # 3. scaled dot-product attention, computed per head
        out, attention = self.attention(q, k, v, mask=mask)

        # 4. concatenate the heads and apply the output projection
        out = self.concat(out)
        out = self.w_concat(out)

        return out
|
|
|
def split(self, tensor): |
|
""" |
|
split tensor by number of head |
|
|
|
:param tensor: [batch_size, length, d_model] |
|
:return: [batch_size, head, length, d_tensor] |
|
""" |
|
batch_size, length, d_model = tensor.size() |
|
|
|
d_tensor = d_model // self.n_head |
|
tensor = tensor.view(batch_size, length, self.n_head, d_tensor).transpose(1, 2) |
|
|
|
|
|
return tensor |
|
|
|
def concat(self, tensor): |
|
""" |
|
inverse function of self.split(tensor : torch.Tensor) |
|
|
|
:param tensor: [batch_size, head, length, d_tensor] |
|
:return: [batch_size, length, d_model] |
|
""" |
|
batch_size, head, length, d_tensor = tensor.size() |
|
d_model = head * d_tensor |
|
|
|
tensor = tensor.transpose(1, 2).contiguous().view(batch_size, length, d_model) |
|
return tensor |
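

# Illustrative sketch (not part of the original model): split() and concat() are exact inverses.
# d_model=16 and n_head=4 are hypothetical demo values.
def _demo_split_concat_roundtrip():
    mha = MultiHeadAttention(d_model=16, n_head=4)
    x = torch.randn(2, 5, 16)                  # [batch_size, length, d_model]
    heads = mha.split(x)                       # [batch_size, head, length, d_tensor] == [2, 4, 5, 4]
    assert heads.shape == (2, 4, 5, 4)
    assert torch.equal(mha.concat(heads), x)   # concat undoes split exactly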
|
|
|
|
|
class LayerNorm(nn.Module): |
|
def __init__(self, d_model, eps=1e-12): |
|
super(LayerNorm, self).__init__() |
|
self.gamma = nn.Parameter(torch.ones(d_model)) |
|
self.beta = nn.Parameter(torch.zeros(d_model)) |
|
self.eps = eps |
|
|
|
    def forward(self, x):
        # statistics are computed over the last (feature) dimension
        mean = x.mean(-1, keepdim=True)
        var = x.var(-1, unbiased=False, keepdim=True)
|
|
|
|
|
out = (x - mean) / torch.sqrt(var + self.eps) |
|
out = self.gamma * out + self.beta |
|
return out |
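

# Illustrative sketch (not part of the original model): with its default gamma=1 / beta=0 the
# hand-written LayerNorm above should agree with torch.nn.LayerNorm given the same eps.
def _demo_layer_norm_matches_torch():
    d_model = 8
    x = torch.randn(2, 5, d_model)
    custom = LayerNorm(d_model=d_model, eps=1e-5)
    reference = nn.LayerNorm(d_model, eps=1e-5)
    assert torch.allclose(custom(x), reference(x), atol=1e-5)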
|
|
|
|
|
class TransformerEmbedding(nn.Module): |
|
""" |
|
token embedding + positional encoding (sinusoid) |
|
positional encoding can give positional information to network |
|
""" |
|
|
|
def __init__(self, vocab_size, d_model, max_len, drop_prob, padding_idx, learnable_pos_emb=True): |
|
""" |
|
class for word embedding that included positional information |
|
|
|
:param vocab_size: size of vocabulary |
|
:param d_model: dimensions of model |
|
""" |
|
super(TransformerEmbedding, self).__init__() |
|
self.tok_emb = TokenEmbedding(vocab_size, d_model, padding_idx) |
|
if learnable_pos_emb: |
|
self.pos_emb = LearnablePositionalEncoding(d_model, max_len) |
|
else: |
|
self.pos_emb = SinusoidalPositionalEncoding(d_model, max_len) |
|
self.drop_out = nn.Dropout(p=drop_prob) |
|
|
|
def forward(self, x): |
|
tok_emb = self.tok_emb(x) |
|
pos_emb = self.pos_emb(x).to(tok_emb.device) |
|
return self.drop_out(tok_emb + pos_emb) |
|
|
|
|
|
class TokenEmbedding(nn.Embedding): |
|
""" |
|
Token Embedding using torch.nn |
|
they will dense representation of word using weighted matrix |
|
""" |
|
|
|
def __init__(self, vocab_size, d_model, padding_idx): |
|
""" |
|
class for token embedding that included positional information |
|
|
|
:param vocab_size: size of vocabulary |
|
:param d_model: dimensions of model |
|
""" |
|
super(TokenEmbedding, self).__init__(vocab_size, d_model, padding_idx=padding_idx) |
|
|
|
|
|
class SinusoidalPositionalEncoding(nn.Module): |
|
""" |
|
compute sinusoid encoding. |
|
""" |
|
|
|
def __init__(self, d_model, max_len): |
|
""" |
|
constructor of sinusoid encoding class |
|
|
|
:param d_model: dimension of model |
|
:param max_len: max sequence length |
|
|
|
""" |
|
super(SinusoidalPositionalEncoding, self).__init__() |
|
|
|
|
|
        # encoding table: same size as the input embeddings so the two can be added directly
        self.encoding = torch.zeros(max_len, d_model)
        self.encoding.requires_grad = False  # fixed table, not a learnable parameter

        # positions 0 .. max_len-1 as a column vector: [max_len, 1]
        pos = torch.arange(0, max_len)
        pos = pos.float().unsqueeze(dim=1)

        # even dimension indices 0, 2, 4, ...; used for both the sin and cos terms
        _2i = torch.arange(0, d_model, step=2).float()

        # sin on even dimensions, cos on odd dimensions
        self.encoding[:, 0::2] = torch.sin(pos / (10000 ** (_2i / d_model)))
        self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))
|
|
|
|
|
    def forward(self, x):
        # x: [batch_size, seq_len] token ids; only the sequence length is needed here
        batch_size, seq_len = x.size()

        # [seq_len, d_model]; broadcasts over the batch when added to the token embeddings
        return self.encoding[:seq_len, :]
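

# Illustrative sketch (not part of the original model): the sinusoidal table has one row per
# position, sin at even dimensions and cos at odd ones, so the row for position 0 is
# [0, 1, 0, 1, ...]. d_model=6 and max_len=10 are arbitrary demo values.
def _demo_sinusoidal_encoding():
    pe = SinusoidalPositionalEncoding(d_model=6, max_len=10)
    ids = torch.zeros(2, 4, dtype=torch.long)             # dummy token ids, only the shape matters
    enc = pe(ids)                                          # [seq_len, d_model] == [4, 6]
    assert enc.shape == (4, 6)
    assert torch.allclose(enc[0, 0::2], torch.zeros(3))   # sin(0) = 0 at even dimensions
    assert torch.allclose(enc[0, 1::2], torch.ones(3))    # cos(0) = 1 at odd dimensions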
|
|
|
|
|
|
|
|
|
class LearnablePositionalEncoding(nn.Module): |
|
""" |
|
compute sinusoid encoding. |
|
""" |
|
|
|
def __init__(self, d_model, max_seq_len): |
|
""" |
|
constructor of learnable positonal encoding class |
|
|
|
:param d_model: dimension of model |
|
:param max_seq_len: max sequence length |
|
|
|
""" |
|
super(LearnablePositionalEncoding, self).__init__() |
|
self.max_seq_len = max_seq_len |
|
self.wpe = nn.Embedding(max_seq_len, d_model) |
|
|
|
def forward(self, x): |
|
|
|
|
|
device = x.device |
|
batch_size, seq_len = x.size() |
|
assert seq_len <= self.max_seq_len, f"Cannot forward sequence of length {seq_len}, max_seq_len is {self.max_seq_len}" |
|
        pos = torch.arange(0, seq_len, dtype=torch.long, device=device)  # [seq_len]
        pos_emb = self.wpe(pos)  # [seq_len, d_model]; broadcasts over the batch dimension

        return pos_emb
|
|
|
|
|
|
|
|
|
class Encoder(nn.Module): |
|
|
|
def __init__(self, enc_voc_size, max_len, d_model, ffn_hidden, n_head, n_layers, drop_prob, padding_idx, learnable_pos_emb=True): |
|
super().__init__() |
|
self.emb = TransformerEmbedding(d_model=d_model, |
|
max_len=max_len, |
|
vocab_size=enc_voc_size, |
|
drop_prob=drop_prob, |
|
padding_idx=padding_idx, |
|
learnable_pos_emb=learnable_pos_emb |
|
) |
|
|
|
self.layers = nn.ModuleList([EncoderLayer(d_model=d_model, |
|
ffn_hidden=ffn_hidden, |
|
n_head=n_head, |
|
drop_prob=drop_prob) |
|
for _ in range(n_layers)]) |
|
|
|
def forward(self, x, s_mask): |
|
x = self.emb(x) |
|
|
|
for layer in self.layers: |
|
x = layer(x, s_mask) |
|
|
|
return x |
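

# Illustrative sketch (not part of the original model): runs the Encoder on a tiny batch of
# token ids with an all-True padding mask. All sizes and the padding_idx are arbitrary demo values.
def _demo_encoder_forward():
    enc = Encoder(enc_voc_size=20, max_len=16, d_model=8, ffn_hidden=32,
                  n_head=2, n_layers=1, drop_prob=0.0, padding_idx=0)
    ids = torch.randint(1, 20, (2, 5))                # [batch_size, seq_len]
    mask = torch.ones(2, 1, 5, 5, dtype=torch.bool)   # no position is masked out
    assert enc(ids, mask).shape == (2, 5, 8)          # [batch_size, seq_len, d_model]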
|
|
|
class Decoder(nn.Module): |
|
def __init__(self, dec_voc_size, max_len, d_model, ffn_hidden, n_head, n_layers, drop_prob, padding_idx, learnable_pos_emb=True): |
|
super().__init__() |
|
self.emb = TransformerEmbedding(d_model=d_model, |
|
drop_prob=drop_prob, |
|
max_len=max_len, |
|
vocab_size=dec_voc_size, |
|
padding_idx=padding_idx, |
|
learnable_pos_emb=learnable_pos_emb |
|
) |
|
|
|
self.layers = nn.ModuleList([DecoderLayer(d_model=d_model, |
|
ffn_hidden=ffn_hidden, |
|
n_head=n_head, |
|
drop_prob=drop_prob) |
|
for _ in range(n_layers)]) |
|
|
|
self.linear = nn.Linear(d_model, dec_voc_size) |
|
|
|
def forward(self, trg, enc_src, trg_mask, src_mask): |
|
trg = self.emb(trg) |
|
|
|
for layer in self.layers: |
|
trg = layer(trg, enc_src, trg_mask, src_mask) |
|
|
|
|
|
output = self.linear(trg) |
|
return output |
|
|
|
class Transformer(nn.Module): |
|
|
|
def __init__(self, src_pad_idx, trg_pad_idx, enc_voc_size, dec_voc_size, d_model, n_head, max_len, |
|
ffn_hidden, n_layers, drop_prob, learnable_pos_emb=True): |
|
super().__init__() |
|
self.src_pad_idx = src_pad_idx |
|
self.trg_pad_idx = trg_pad_idx |
|
self.encoder = Encoder(d_model=d_model, |
|
n_head=n_head, |
|
max_len=max_len, |
|
ffn_hidden=ffn_hidden, |
|
enc_voc_size=enc_voc_size, |
|
drop_prob=drop_prob, |
|
n_layers=n_layers, |
|
padding_idx=src_pad_idx, |
|
learnable_pos_emb=learnable_pos_emb) |
|
|
|
self.decoder = Decoder(d_model=d_model, |
|
n_head=n_head, |
|
max_len=max_len, |
|
ffn_hidden=ffn_hidden, |
|
dec_voc_size=dec_voc_size, |
|
drop_prob=drop_prob, |
|
n_layers=n_layers, |
|
padding_idx=trg_pad_idx, |
|
learnable_pos_emb=learnable_pos_emb) |
|
|
|
def get_device(self): |
|
return next(self.parameters()).device |
|
|
|
    def forward(self, src, trg):
        device = self.get_device()
        # source self-attention mask: padding only
        src_mask = self.make_pad_mask(src, src, self.src_pad_idx, self.src_pad_idx).to(device)
        # encoder-decoder attention mask: padding only
        src_trg_mask = self.make_pad_mask(trg, src, self.trg_pad_idx, self.src_pad_idx).to(device)
        # target self-attention mask: padding combined with the no-peak (causal) mask
        trg_mask = self.make_pad_mask(trg, trg, self.trg_pad_idx, self.trg_pad_idx).to(device) * \
                   self.make_no_peak_mask(trg, trg).to(device)

        enc_src = self.encoder(src, src_mask)
        output = self.decoder(trg, enc_src, trg_mask, src_trg_mask)
        return output
|
|
|
    def make_pad_mask(self, q, k, q_pad_idx, k_pad_idx):
        len_q, len_k = q.size(1), k.size(1)

        # batch_size x 1 x 1 x len_k  ->  batch_size x 1 x len_q x len_k
        k = k.ne(k_pad_idx).unsqueeze(1).unsqueeze(2)
        k = k.repeat(1, 1, len_q, 1)

        # batch_size x 1 x len_q x 1  ->  batch_size x 1 x len_q x len_k
        q = q.ne(q_pad_idx).unsqueeze(1).unsqueeze(3)
        q = q.repeat(1, 1, 1, len_k)

        # True only where both the query and the key position are real (non-padding) tokens
        mask = k & q
        return mask

    def make_no_peak_mask(self, q, k):
        len_q, len_k = q.size(1), k.size(1)

        # lower-triangular matrix: position i may only attend to positions <= i
        mask = torch.tril(torch.ones(len_q, len_k)).type(torch.BoolTensor)

        return mask
|
|
|
|
|
def make_pad_mask(x, pad_idx):
    """
    Standalone version of Transformer.make_pad_mask for the symmetric case where the
    query and key come from the same sequence.
    """
    q = k = x
    q_pad_idx = k_pad_idx = pad_idx
    len_q, len_k = q.size(1), k.size(1)

    # batch_size x 1 x 1 x len_k  ->  batch_size x 1 x len_q x len_k
    k = k.ne(k_pad_idx).unsqueeze(1).unsqueeze(2)
    k = k.repeat(1, 1, len_q, 1)

    # batch_size x 1 x len_q x 1  ->  batch_size x 1 x len_q x len_k
    q = q.ne(q_pad_idx).unsqueeze(1).unsqueeze(3)
    q = q.repeat(1, 1, 1, len_k)

    mask = k & q
    return mask
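

# Illustrative sketch (not part of the original model): with a hypothetical pad_idx of 0 the
# standalone make_pad_mask marks every (query, key) pair that touches a padding token as False.
def _demo_make_pad_mask():
    x = torch.tensor([[5, 6, 0]])           # one sequence, last token is padding
    mask = make_pad_mask(x, pad_idx=0)      # [batch_size, 1, len, len]
    assert mask.shape == (1, 1, 3, 3)
    assert bool(mask[0, 0, 0, 0]) is True   # real token attending to a real token
    assert bool(mask[0, 0, 0, 2]) is False  # attending to the padding position is masked
    assert bool(mask[0, 0, 2, 0]) is False  # queries at the padding position are masked too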
|
|
|
|
|
from torch.nn.utils.rnn import pad_sequence |
|
|
|
def pad_seq_v2(sequences, batch_first=True, padding_value=0.0, prepadding=True):
    """
    Pad a list of variable-length sequences to a common length.
    With prepadding=True the padding is placed in front of each sequence instead of after it.
    """
    lens = [i.shape[0] for i in sequences]
    # always pad on the right with batch_first=True internally ...
    padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=padding_value)
    if prepadding:
        # ... then roll each row to the left by its true length so the padding ends up in front
        for i in range(len(lens)):
            padded_sequences[i] = padded_sequences[i].roll(-lens[i])
    if not batch_first:
        padded_sequences = padded_sequences.transpose(0, 1)
    return padded_sequences
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
    import random

    import numpy as np
|
|
|
rand_seed = 10 |
|
|
|
device = 'cpu' |
|
|
|
|
|
batch_size = 128 |
|
max_len = 256 |
|
d_model = 512 |
|
n_layers = 3 |
|
n_heads = 16 |
|
ffn_hidden = 2048 |
|
drop_prob = 0.1 |
|
|
|
|
|
    # training hyper-parameters (kept from the original script; unused in this quick shape check)
    init_lr = 1e-5
|
factor = 0.9 |
|
adam_eps = 5e-9 |
|
patience = 10 |
|
warmup = 100 |
|
epoch = 1000 |
|
clip = 1.0 |
|
weight_decay = 5e-4 |
|
inf = float('inf') |
|
|
|
src_pad_idx = 2 |
|
trg_pad_idx = 3 |
|
|
|
enc_voc_size = 37 |
|
dec_voc_size = 15 |
|
model = Transformer(src_pad_idx=src_pad_idx, |
|
trg_pad_idx=trg_pad_idx, |
|
d_model=d_model, |
|
enc_voc_size=enc_voc_size, |
|
dec_voc_size=dec_voc_size, |
|
max_len=max_len, |
|
ffn_hidden=ffn_hidden, |
|
n_head=n_heads, |
|
n_layers=n_layers, |
|
drop_prob=drop_prob |
|
).to(device) |
|
|
|
random.seed(rand_seed) |
|
|
|
np.random.seed(rand_seed) |
|
torch.manual_seed(rand_seed) |
|
|
|
x_list = [ |
|
torch.tensor([[1, 1]]).transpose(0, 1), |
|
torch.tensor([[1, 1, 1, 1, 1, 1, 1]]).transpose(0, 1), |
|
torch.tensor([[1, 1, 1]]).transpose(0, 1) |
|
] |
|
|
|
|
|
src_pad_idx = model.src_pad_idx |
|
trg_pad_idx = model.trg_pad_idx |
|
|
|
src = pad_seq_v2(x_list, padding_value=src_pad_idx, prepadding=False).squeeze(2) |
|
trg = pad_seq_v2(x_list, padding_value=trg_pad_idx, prepadding=False).squeeze(2) |
|
out = model(src, trg) |
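
    # quick sanity check (added sketch): the decoder produces one logit per target vocabulary
    # entry for every target position, i.e. [batch_size, trg_len, dec_voc_size]
    print('output shape:', tuple(out.shape))
    assert out.shape == (src.size(0), trg.size(1), dec_voc_size)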
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|