import torch

from transformers import ElectraForTokenClassification
from torchcrf import CRF

from src.utils.mapper import configmapper


@configmapper.map("models", "bert_crf_token")
class BertLSTMCRF(ElectraForTokenClassification):
    """Electra encoder followed by a BiLSTM and a CRF head for token classification.

    The class keeps its original name, but the backbone was switched from
    BertForTokenClassification to ElectraForTokenClassification.
    """

    def __init__(self, config, lstm_hidden_size, lstm_layers):
        super().__init__(config)
        # BiLSTM on top of the Electra hidden states.
        self.lstm = torch.nn.LSTM(
            input_size=config.hidden_size,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_layers,
            dropout=0.2,
            batch_first=True,
            bidirectional=True,
        )
        self.crf = CRF(config.num_labels, batch_first=True)
        # Replace the default classifier so it matches the BiLSTM output size
        # (2 * lstm_hidden_size because the LSTM is bidirectional).
        del self.classifier
        self.classifier = torch.nn.Linear(2 * lstm_hidden_size, config.num_labels)

    def forward(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        prediction_mask=None,
    ):
        outputs = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            output_hidden_states=True,
            return_dict=False,
        )
        # outputs[0] is the sequence of hidden states; Electra has no pooled output.
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)

        lstm_out, *_ = self.lstm(sequence_output)
        sequence_output = self.dropout(lstm_out)
        logits = self.classifier(sequence_output)

        # CRF: restrict the mask (and labels) to the emitted sequence length.
        mask = prediction_mask
        mask = mask[:, : logits.size(1)].contiguous()

        if labels is not None:
            labels = labels[:, : logits.size(1)].contiguous()
            loss = -self.crf(logits, labels, mask=mask.bool(), reduction="token_mean")

        tags = self.crf.decode(logits, mask.bool())

        if labels is not None:
            return (loss, logits, tags)
        return (logits, tags)
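

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): instantiates the
# model from a freshly constructed ElectraConfig and runs a dummy batch.
# The label count, hidden sizes, and tensor shapes below are illustrative
# assumptions, not values taken from the project's configs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import ElectraConfig

    config = ElectraConfig(num_labels=5)
    model = BertLSTMCRF(config, lstm_hidden_size=128, lstm_layers=2)

    batch_size, seq_len = 2, 16
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones_like(input_ids)
    prediction_mask = torch.ones_like(input_ids)
    labels = torch.randint(0, config.num_labels, (batch_size, seq_len))

    loss, logits, tags = model(
        input_ids,
        attention_mask=attention_mask,
        labels=labels,
        prediction_mask=prediction_mask,
    )
    # tags is a list of per-sequence label index lists from CRF decoding.
    print(loss.item(), logits.shape, len(tags))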