nsorros commited on
Commit
731f527
1 Parent(s): 6e2698b
Files changed (2) hide show
  1. config.json +0 -0
  2. model.py +5 -1
config.json CHANGED
The diff for this file is too large to render. See raw diff
 
model.py CHANGED
@@ -29,6 +29,7 @@ class BertMesh(PreTrainedModel):
29
  self.hidden_size = getattr(self.config, "hidden_size", 512)
30
  self.dropout = getattr(self.config, "dropout", 0.1)
31
  self.multilabel_attention = getattr(self.config, "multilabel_attention", False)
 
32
 
33
  self.bert = AutoModel.from_pretrained(self.pretrained_model) # 768
34
  self.multilabel_attention_layer = MultiLabelAttention(
@@ -39,7 +40,7 @@ class BertMesh(PreTrainedModel):
39
  self.linear_out = torch.nn.Linear(self.hidden_size, self.num_labels)
40
  self.dropout_layer = torch.nn.Dropout(self.dropout)
41
 
42
- def forward(self, input_ids, **kwargs):
43
  if type(input_ids) is list:
44
  # coming from tokenizer
45
  input_ids = torch.tensor(input_ids)
@@ -55,4 +56,7 @@ class BertMesh(PreTrainedModel):
55
  outs = torch.nn.functional.relu(self.linear_1(cls))
56
  outs = self.dropout_layer(outs)
57
  outs = torch.sigmoid(self.linear_out(outs))
 
 
 
58
  return outs
 
29
  self.hidden_size = getattr(self.config, "hidden_size", 512)
30
  self.dropout = getattr(self.config, "dropout", 0.1)
31
  self.multilabel_attention = getattr(self.config, "multilabel_attention", False)
32
+ self.id2label = self.config.id2label
33
 
34
  self.bert = AutoModel.from_pretrained(self.pretrained_model) # 768
35
  self.multilabel_attention_layer = MultiLabelAttention(
 
40
  self.linear_out = torch.nn.Linear(self.hidden_size, self.num_labels)
41
  self.dropout_layer = torch.nn.Dropout(self.dropout)
42
 
43
+ def forward(self, input_ids, return_labels=False, **kwargs):
44
  if type(input_ids) is list:
45
  # coming from tokenizer
46
  input_ids = torch.tensor(input_ids)
 
56
  outs = torch.nn.functional.relu(self.linear_1(cls))
57
  outs = self.dropout_layer(outs)
58
  outs = torch.sigmoid(self.linear_out(outs))
59
+ if return_labels:
60
+ # TODO Vectorize
61
+ outs = [[self.id2label[label_id] for label_id, label_prob in enumerate(out) if label_prob > 0.5] for out in outs]
62
  return outs