File size: 1,166 Bytes
a8700fa
be9811f
 
 
 
 
 
a8700fa
be9811f
 
 
 
 
 
 
ee651c6
 
 
 
 
be9811f
 
ee651c6
be9811f
 
 
 
 
 
 
a8700fa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
from torch.nn import CrossEntropyLoss
from transformers import AutoModel, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

from .configuration_roberta_emotion import RobertaEmotionConfig


class RobertaEmotion(PreTrainedModel):
    """RoBERTa sequence classifier with a dropout + linear head on the first token.

    The backbone is always loaded from the pretrained "roberta-base" checkpoint.
    The head reads ``config.hidden_size`` as its input width and
    ``config.num_labels`` as its output width. ``config.hidden_size`` must
    therefore match the backbone's hidden size, or the head's matmul will fail
    at forward time.
    """

    config_class = RobertaEmotionConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        # NOTE(review): ``config`` is passed positionally, so transformers
        # treats it as an extra model argument for the backbone's constructor,
        # not as the backbone config. The backbone is presumably built from the
        # "roberta-base" checkpoint's own config. Left as-is to preserve the
        # existing parameter layout; confirm intent before switching to
        # ``config=config`` (AutoModel may not recognize a custom config class).
        self.backbone = AutoModel.from_pretrained("roberta-base", config)
        self.classifier = torch.nn.Sequential(
            torch.nn.Dropout(p=0.1),
            torch.nn.Linear(config.hidden_size, config.num_labels),
        )
        # Xavier-normal init for the head's weight matrix only; the bias keeps
        # torch.nn.Linear's default initialization.
        torch.nn.init.xavier_normal_(self.classifier[1].weight)

    def forward(self, input_ids, labels=None, attention_mask=None):
        """Classify a batch of token sequences.

        Args:
            input_ids: Token ids fed to the backbone (presumably shaped
                ``(batch, seq_len)``; TODO confirm against callers).
            labels: Optional target class indices. When given, a
                cross-entropy loss is computed against the logits.
            attention_mask: Optional padding mask forwarded to the backbone so
                padded positions are not attended to. With ``None``, the
                backbone applies its own default (no masking).

        Returns:
            SequenceClassifierOutput with ``logits`` of shape
            ``(batch, num_labels)`` and ``loss`` (``None`` when ``labels`` is
            ``None``).
        """
        # Forward the mask to the backbone. Previously it was accepted but
        # dropped, so in padded batches the padding tokens influenced the
        # first-token representation used below.
        hidden_states = self.backbone(
            input_ids, attention_mask=attention_mask
        ).last_hidden_state
        # Summarize each sequence by its first position (presumably RoBERTa's
        # <s> token) and project it to one logit per label.
        logits = self.classifier(hidden_states[:, 0, :])

        loss = None
        if labels is not None:
            # Move labels next to the logits in case they arrive on another device.
            labels = labels.to(logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return SequenceClassifierOutput(loss=loss, logits=logits)