Spaces:
Running
Running
File size: 414 Bytes
0edbb0d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
from .base import BaseModel
from .bert_modules.bert import BERT
import torch.nn as nn
class BERTModel(BaseModel):
    """A ``BERT`` encoder followed by a single linear output head.

    The head projects the encoder's ``hidden`` feature size onto
    ``args.num_items + 1`` output scores. ``nn.Linear`` acts on the last
    dimension of the encoder output.
    """

    @classmethod
    def code(cls):
        """Return this model's short identifier string, ``'bert'``."""
        # NOTE(review): presumably used to select the model by name elsewhere — confirm.
        return 'bert'

    def __init__(self, args):
        """Build the encoder and the output head from the ``args`` namespace.

        Args:
            args: config object. It is passed to ``BaseModel`` and ``BERT``,
                and ``args.num_items`` is read directly here.
        """
        super().__init__(args)
        # Register the encoder before the head. Submodule order (and thus
        # state_dict key order and parameter-init order) must match the original.
        encoder = BERT(args)
        self.bert = encoder
        # The +1 presumably reserves one extra output index (e.g. padding/mask) — TODO confirm.
        vocab_size = args.num_items + 1
        self.out = nn.Linear(encoder.hidden, vocab_size)

    def forward(self, x):
        """Encode ``x`` with ``self.bert`` and return ``self.out`` applied to the result.

        NOTE(review): the shape/dtype of ``x`` is whatever ``BERT.forward``
        expects; it is not visible from this file.
        """
        return self.out(self.bert(x))
|