from torch import nn

from models.bert_modules.embedding import BERTEmbedding
from models.bert_modules.transformer import TransformerBlock
from utils import fix_random_seed_as


class BERT(nn.Module):
    def __init__(self, args):
        super().__init__()

        fix_random_seed_as(args.model_init_seed)
        # self.init_weights()

        max_len = args.bert_max_len
        num_items = args.num_items
        n_layers = args.bert_num_blocks
        heads = args.bert_num_heads
        vocab_size = num_items + 2  # +2 reserves ids for the padding (0) and mask tokens
        hidden = args.bert_hidden_units
        self.hidden = hidden
        dropout = args.bert_dropout

        # embedding for BERT: sum of positional, segment, and token embeddings
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=self.hidden,
                                       max_len=max_len, dropout=dropout)

        # multi-layer transformer blocks (deep network)
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(hidden, heads, hidden * 4, dropout) for _ in range(n_layers)])

    def forward(self, x):
        # attention mask: allow attending only to non-padding positions (token id > 0);
        # shape becomes (batch, 1, seq_len, seq_len) for broadcasting over heads
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)

        # embed the indexed sequence into a sequence of vectors
        x = self.embedding(x)

        # run over the stacked transformer blocks
        for transformer in self.transformer_blocks:
            x = transformer(x, mask)  # call the module directly rather than .forward()

        return x

    def init_weights(self):
        # placeholder; custom initialization is currently disabled (see __init__)
        pass
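

# --- Hedged usage sketch (illustrative, not part of the original module) ---
# Shows how BERT is typically constructed from an args namespace. The attribute
# names mirror those read in __init__; every concrete value below is an
# assumption chosen for demonstration only.
if __name__ == '__main__':
    from argparse import Namespace
    import torch

    args = Namespace(
        model_init_seed=0,      # seed forwarded to fix_random_seed_as
        bert_max_len=100,       # maximum input sequence length
        num_items=1000,         # item vocabulary size (+2 added for padding/mask)
        bert_num_blocks=2,      # number of stacked TransformerBlocks
        bert_num_heads=4,       # attention heads per block
        bert_hidden_units=256,  # embedding / hidden dimension
        bert_dropout=0.1,       # dropout probability
    )
    model = BERT(args)
    batch = torch.randint(1, args.num_items + 1, (8, args.bert_max_len))  # dummy item ids
    out = model(batch)  # expected shape: (8, bert_max_len, bert_hidden_units)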