mramazan's picture
Upload 60 files
426ffb5 verified
raw
history blame
414 Bytes
from .base import BaseModel
from .bert_modules.bert import BERT
import torch.nn as nn
class BERTModel(BaseModel):
    """Model built from a ``BERT`` encoder followed by a linear output head.

    The head maps each encoder output vector (width ``self.bert.hidden``)
    to ``args.num_items + 1`` scores.
    NOTE(review): the extra "+1" slot is presumably a reserved index
    (e.g. padding or mask token) — confirm against the dataset's item
    indexing.
    """

    @classmethod
    def code(cls):
        """Return this model's short identifier string, ``'bert'``."""
        # Presumably used to select the model by name from a registry or
        # config; verify against the caller.
        return 'bert'

    def __init__(self, args):
        """Create the encoder and the output projection from ``args``.

        This method reads ``args.num_items`` directly. ``args`` is also
        passed whole to ``BaseModel.__init__`` and to ``BERT``, which may
        read further fields.
        """
        # The parent initializer runs before any submodule is assigned.
        # NOTE(review): BaseModel is presumably an nn.Module subclass, where
        # submodules cannot be registered before Module.__init__ has run.
        super().__init__(args)
        self.bert = BERT(args)
        encoder_width = self.bert.hidden
        num_outputs = args.num_items + 1
        self.out = nn.Linear(encoder_width, num_outputs)

    def forward(self, x):
        """Run ``x`` through the BERT encoder, then the linear head.

        Returns the output of ``self.out`` applied to the encoder output.
        NOTE(review): ``x`` is presumably a batch of item-index sequences,
        and the result presumably has shape (batch, seq_len, num_items + 1)
        — TODO confirm.
        """
        encoded = self.bert(x)
        scores = self.out(encoded)
        return scores