add model
Browse files- config.json +0 -0
- model.py +5 -1
config.json
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
model.py
CHANGED
@@ -29,6 +29,7 @@ class BertMesh(PreTrainedModel):
|
|
29 |
self.hidden_size = getattr(self.config, "hidden_size", 512)
|
30 |
self.dropout = getattr(self.config, "dropout", 0.1)
|
31 |
self.multilabel_attention = getattr(self.config, "multilabel_attention", False)
|
|
|
32 |
|
33 |
self.bert = AutoModel.from_pretrained(self.pretrained_model) # 768
|
34 |
self.multilabel_attention_layer = MultiLabelAttention(
|
@@ -39,7 +40,7 @@ class BertMesh(PreTrainedModel):
|
|
39 |
self.linear_out = torch.nn.Linear(self.hidden_size, self.num_labels)
|
40 |
self.dropout_layer = torch.nn.Dropout(self.dropout)
|
41 |
|
42 |
-
def forward(self, input_ids, **kwargs):
|
43 |
if type(input_ids) is list:
|
44 |
# coming from tokenizer
|
45 |
input_ids = torch.tensor(input_ids)
|
@@ -55,4 +56,7 @@ class BertMesh(PreTrainedModel):
|
|
55 |
outs = torch.nn.functional.relu(self.linear_1(cls))
|
56 |
outs = self.dropout_layer(outs)
|
57 |
outs = torch.sigmoid(self.linear_out(outs))
|
|
|
|
|
|
|
58 |
return outs
|
|
|
29 |
self.hidden_size = getattr(self.config, "hidden_size", 512)
|
30 |
self.dropout = getattr(self.config, "dropout", 0.1)
|
31 |
self.multilabel_attention = getattr(self.config, "multilabel_attention", False)
|
32 |
+
self.id2label = self.config.id2label
|
33 |
|
34 |
self.bert = AutoModel.from_pretrained(self.pretrained_model) # 768
|
35 |
self.multilabel_attention_layer = MultiLabelAttention(
|
|
|
40 |
self.linear_out = torch.nn.Linear(self.hidden_size, self.num_labels)
|
41 |
self.dropout_layer = torch.nn.Dropout(self.dropout)
|
42 |
|
43 |
+
def forward(self, input_ids, return_labels=False, **kwargs):
|
44 |
if type(input_ids) is list:
|
45 |
# coming from tokenizer
|
46 |
input_ids = torch.tensor(input_ids)
|
|
|
56 |
outs = torch.nn.functional.relu(self.linear_1(cls))
|
57 |
outs = self.dropout_layer(outs)
|
58 |
outs = torch.sigmoid(self.linear_out(outs))
|
59 |
+
if return_labels:
|
60 |
+
# TODO Vectorize
|
61 |
+
outs = [[self.id2label[label_id] for label_id, label_prob in enumerate(out) if label_prob > 0.5] for out in outs]
|
62 |
return outs
|