"""RoBERTa model with token-level start/end heads for multi-span extraction."""

import torch.nn as nn
from torch.nn import BCEWithLogitsLoss
from transformers import RobertaModel
from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel

from src.utils.mapper import configmapper


@configmapper.map("models", "roberta_multi_spans")
class RobertaForMultiSpans(RobertaPreTrainedModel):
    """RoBERTa with a linear head that scores every token as a span start/end."""

    def __init__(self, config):
        super().__init__(config)
        self.roberta = RobertaModel(config)
        self.num_labels = config.num_labels

        # A single linear head produces two logits per token: [start, end].
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=None,  # falls back to config.use_return_dict
        )

        sequence_output = outputs[0]  # (batch_size, seq_len, hidden_size)

        # Score every token independently as a potential span start and span end.
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)  # (batch_size, seq_len)
        end_logits = end_logits.squeeze(-1)  # (batch_size, seq_len)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # Multi-span supervision: start_positions / end_positions are
            # multi-hot vectors over the sequence, so each token is scored
            # independently with a sigmoid rather than a softmax over positions.
            loss_fct = BCEWithLogitsLoss()
            start_loss = loss_fct(start_logits, start_positions.float())
            end_loss = loss_fct(end_logits, end_positions.float())
            total_loss = (start_loss + end_loss) / 2

        output = (start_logits, end_logits) + outputs[2:]
        return ((total_loss,) + output) if total_loss is not None else output
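

# Minimal usage sketch (not part of the original module; assumes it is run inside
# the repo so that src.utils.mapper is importable). The tiny config, sequence
# length, and multi-hot targets below are illustrative values, not repo settings.
if __name__ == "__main__":
    import torch
    from transformers import RobertaConfig

    # num_labels=2 yields one start logit and one end logit per token.
    config = RobertaConfig(
        vocab_size=1000,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=128,
        num_labels=2,
    )
    model = RobertaForMultiSpans(config)
    model.eval()

    batch_size, seq_len = 2, 16
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

    # Multi-hot span supervision: several tokens per sequence may be starts/ends.
    start_positions = torch.zeros(batch_size, seq_len)
    end_positions = torch.zeros(batch_size, seq_len)
    start_positions[:, 3] = 1.0
    end_positions[:, 5] = 1.0

    loss, start_logits, end_logits = model(
        input_ids,
        attention_mask=attention_mask,
        start_positions=start_positions,
        end_positions=end_positions,
    )[:3]
    print(loss.item(), start_logits.shape, end_logits.shape)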