Spaces:
Running
Running
from .base import BaseModel | |
from .bert_modules.bert import BERT | |
import torch.nn as nn | |
class BERTModel(BaseModel):
    """BERT encoder followed by a linear output layer over the item vocabulary.

    The encoder is built from ``args`` via ``BERT(args)``. Its hidden
    representation (width ``self.bert.hidden``) is projected to
    ``args.num_items + 1`` output scores.
    """

    def __init__(self, args):
        """Build the encoder and the output projection.

        Args:
            args: Configuration object. It is passed unchanged to
                ``BaseModel`` and ``BERT``, and ``args.num_items`` is read
                here to size the output layer.
        """
        super().__init__(args)
        self.bert = BERT(args)
        # One extra output beyond num_items; presumably index 0 is reserved
        # for padding/masking in the item vocabulary — TODO confirm.
        self.out = nn.Linear(self.bert.hidden, args.num_items + 1)

    @classmethod
    def code(cls):
        """Return the short identifier string for this model type."""
        # Must be a classmethod (the parameter is ``cls``): without the
        # decorator, ``BERTModel.code()`` raises TypeError for the missing
        # argument. Instance calls keep working either way.
        return 'bert'

    def forward(self, x):
        """Encode ``x`` with BERT and project to per-item scores.

        Args:
            x: Input passed directly to ``self.bert``. Presumably a batch of
                item-id sequences — TODO confirm against ``BERT.forward``.

        Returns:
            The output of ``self.out`` applied to the encoder output. Its last
            dimension is ``args.num_items + 1``; leading dimensions follow
            the encoder output (likely ``(batch, seq_len)`` — unverified).
        """
        x = self.bert(x)
        return self.out(x)