Spaces:
Running
Running
File size: 1,469 Bytes
0edbb0d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
from torch import nn as nn
from models.bert_modules.embedding import BERTEmbedding
from models.bert_modules.transformer import TransformerBlock
from utils import fix_random_seed_as
from pathlib import Path
import pickle
class BERT(nn.Module):
    """Transformer encoder over sequences of integer item ids.

    Embeds an id sequence with ``BERTEmbedding`` and passes it through a stack
    of ``TransformerBlock`` layers.

    Fields read from ``args``:
        model_init_seed: passed to ``fix_random_seed_as`` before any submodule
            is built (presumably so weight initialisation is reproducible).
        bert_max_len: ``max_len`` given to the embedding.
        num_items: number of items; the embedding vocabulary is
            ``num_items + 2``.
        bert_num_blocks: number of stacked transformer blocks.
        bert_num_heads: attention heads per block.
        bert_hidden_units: hidden size, also stored as ``self.hidden``.
        bert_dropout: dropout rate for the embedding and every block.
    """

    def __init__(self, args):
        super().__init__()
        # Seed before constructing submodules so their parameters are created
        # under the fixed seed.
        fix_random_seed_as(args.model_init_seed)
        max_len = args.bert_max_len
        num_items = args.num_items
        n_layers = args.bert_num_blocks
        heads = args.bert_num_heads
        # +2: presumably id 0 is padding (consistent with the `x > 0` mask in
        # forward) plus one extra special id, likely a [MASK] token -- confirm
        # against the dataloader.
        vocab_size = num_items + 2
        hidden = args.bert_hidden_units
        self.hidden = hidden
        dropout = args.bert_dropout
        # Input embedding (see BERTEmbedding for how token/positional parts
        # are combined).
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=self.hidden, max_len=max_len, dropout=dropout)
        # Stack of transformer blocks; third argument is hidden * 4
        # (presumably the feed-forward width).
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(hidden, heads, hidden * 4, dropout) for _ in range(n_layers)])

    def forward(self, x):
        """Encode a batch of id sequences.

        Args:
            x: integer id tensor of shape (batch, seq_len). Ids ``<= 0`` are
                marked False in the attention mask (treated as padding).

        Returns:
            Output of the last transformer block (the embedding output if
            there are no blocks); presumably shape (batch, seq_len, hidden) --
            TODO confirm against TransformerBlock.
        """
        # Mask of shape (batch, 1, seq_len, seq_len):
        # mask[b, 0, i, j] == (x[b, j] > 0), i.e. for every query row i it
        # flags which key positions j hold real tokens. The singleton dim 1
        # presumably broadcasts over attention heads inside TransformerBlock.
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)
        x = self.embedding(x)
        # Invoke each block through nn.Module.__call__ (not `.forward`) so
        # registered forward hooks / pre-hooks are honoured.
        for transformer in self.transformer_blocks:
            x = transformer(x, mask)
        return x

    def init_weights(self):
        """No-op placeholder; nothing in this class calls it.

        This class applies no custom weight initialisation of its own. The
        method is kept so existing callers that invoke it keep working.
        """
        pass
|