FikriRiyadi commited on
Commit
2611b58
·
verified ·
1 Parent(s): 73bdbb3

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +18 -0
model.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import BertModel
4
+
5
+ class HybridModel(nn.Module):
6
+ def __init__(self, dropout=0.3):
7
+ super(HybridModel, self).__init__()
8
+ self.bert = BertModel.from_pretrained("indobenchmark/indobert-base-p1")
9
+ self.lstm = nn.LSTM(768, 128, bidirectional=True, batch_first=True)
10
+ self.dropout = nn.Dropout(dropout)
11
+ self.classifier = nn.Linear(128 * 2, 10)
12
+
13
+ def forward(self, input_ids, attention_mask):
14
+ with torch.no_grad():
15
+ outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
16
+ lstm_out, _ = self.lstm(outputs.last_hidden_state)
17
+ x = self.dropout(lstm_out[:, -1, :])
18
+ return torch.sigmoid(self.classifier(x))