import torch
import torch.nn as nn
from transformers import RobertaTokenizer, RobertaModel
from transformers.modeling_outputs import SequenceClassifierOutput


class TargetSentimentClassifier(nn.Module):
    def __init__(self, model_name, num_labels=3, dropout_rate=0.3, num_dropouts=5,
                 use_multi_sample_dropout=False, device='cpu'):
        super().__init__()
        self.device = device
        self.tokenizer = RobertaTokenizer.from_pretrained(model_name)
        self.roberta = RobertaModel.from_pretrained(model_name)
        self.use_multi_sample_dropout = use_multi_sample_dropout
        self.num_labels = num_labels

        if self.use_multi_sample_dropout:
            # Multi-sample dropout: several independent dropout masks share one classifier head.
            self.dropouts = nn.ModuleList([nn.Dropout(dropout_rate) for _ in range(num_dropouts)])
        else:
            self.dropout = nn.Dropout(dropout_rate)

        self.classifier = nn.Linear(self.roberta.config.hidden_size, num_labels)
        self.to(device)

    def forward(self, input_ids, attention_mask=None, labels=None):
        outputs = self.roberta(input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state
        # Use the hidden state of the first (<s>/CLS) token as the sentence representation.
        cls_output = sequence_output[:, 0, :]

        if self.use_multi_sample_dropout:
            # Apply each dropout mask, classify, and average the resulting logits.
            logits_list = []
            for dropout in self.dropouts:
                dropped = dropout(cls_output)
                logits_list.append(self.classifier(dropped))
            avg_logits = torch.mean(torch.stack(logits_list), dim=0)
        else:
            dropped = self.dropout(cls_output)
            avg_logits = self.classifier(dropped)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(avg_logits.view(-1, self.num_labels), labels.view(-1))

        return SequenceClassifierOutput(loss=loss, logits=avg_logits)

    def prepare_input(self, sentence, target, entity_type):
        # Wrap the target mention in entity markers so the model can attend to it.
        return sentence.replace(target, f"<en> {target} <|{entity_type}|> </en>")

    def predict(self, sentence, target, entity_type):
        self.eval()
        input_text = self.prepare_input(sentence, target, entity_type)
        inputs = self.tokenizer(
            input_text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            outputs = self(**inputs)
        logits = outputs.logits
        pred = torch.argmax(logits, dim=1).item()
        return pred
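

# Example usage (a minimal, hypothetical sketch: "roberta-base", the sample
# sentence, target, and entity_type below are illustrative assumptions, and the
# meaning of the returned integer depends on the label mapping used during training).
if __name__ == "__main__":
    model = TargetSentimentClassifier("roberta-base", num_labels=3, device="cpu")
    pred = model.predict(
        "The battery life is great, but the screen scratches easily.",
        target="battery life",
        entity_type="feature",
    )
    print(pred)  # an index in [0, num_labels), per the training label map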