import torch
from transformers import RobertaForTokenClassification
from torchcrf import CRF

from src.utils.mapper import configmapper


@configmapper.map("models", "roberta_crf_token")
class RobertaLSTMCRF(RobertaForTokenClassification):
    """RoBERTa encoder followed by a BiLSTM and a CRF layer for token classification."""

    def __init__(self, config, lstm_hidden_size, lstm_layers):
        super().__init__(config)
        # Note: torch.nn.LSTM ignores `dropout` (with a warning) when num_layers == 1.
        self.lstm = torch.nn.LSTM(
            input_size=config.hidden_size,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_layers,
            dropout=0.2,
            batch_first=True,
            bidirectional=True,
        )
        self.crf = CRF(config.num_labels, batch_first=True)

        # Replace the parent's token-classification head: the BiLSTM output is
        # 2 * lstm_hidden_size wide (forward and backward directions concatenated).
        self.classifier = torch.nn.Linear(2 * lstm_hidden_size, config.num_labels)
|
    def forward(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        prediction_mask=None,
    ):
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=False,
        )

        # Last-layer hidden states: (batch_size, seq_len, hidden_size).
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)

        lstm_out, *_ = self.lstm(sequence_output)
        sequence_output = self.dropout(lstm_out)

        # Emission scores for the CRF: (batch_size, seq_len, num_labels).
        logits = self.classifier(sequence_output)

        # Restrict the CRF to positions marked in prediction_mask (e.g. the first
        # sub-token of each word); fall back to attention_mask if none is given.
        mask = prediction_mask if prediction_mask is not None else attention_mask
        mask = mask[:, : logits.size(1)].contiguous().bool()

        if labels is not None:
            labels = labels[:, : logits.size(1)].contiguous()
            # torchcrf indexes emissions by label even at masked positions, so
            # ignore-index values (e.g. -100) must be replaced with a valid id.
            labels = torch.where(labels >= 0, labels, torch.zeros_like(labels))
            loss = -self.crf(logits, labels, mask=mask, reduction="token_mean")

        # Viterbi decoding: list of best label sequences, one per example.
        tags = self.crf.decode(logits, mask)

        if labels is not None:
            return (loss, logits, tags)
        return (logits, tags)
|
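
# A minimal usage sketch (not part of the original module): builds the model
# from a small, hypothetical RobertaConfig and runs one forward pass on dummy
# tensors. The tiny config values and label count below are illustrative
# assumptions only, chosen so the script runs quickly on CPU.
if __name__ == "__main__":
    from transformers import RobertaConfig

    config = RobertaConfig(
        vocab_size=100,  # assumed toy vocabulary size
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=64,
        max_position_embeddings=40,
        num_labels=5,  # assumed label-set size
    )
    model = RobertaLSTMCRF(config, lstm_hidden_size=16, lstm_layers=2)

    input_ids = torch.randint(0, config.vocab_size, (2, 10))
    attention_mask = torch.ones(2, 10, dtype=torch.long)
    labels = torch.randint(0, config.num_labels, (2, 10))

    # Training-style call: returns (loss, logits, tags).
    loss, logits, tags = model(
        input_ids,
        attention_mask=attention_mask,
        labels=labels,
        prediction_mask=attention_mask,  # here: score every non-padding token
    )
    print(loss.item(), logits.shape, tags[0])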