# ruRoberta-large-tsa-news-ru / ruRoberta-large-target-sentiment-classifier-ru-news.py

import torch
import torch.nn as nn
from transformers import RobertaTokenizer, RobertaModel
from transformers.modeling_outputs import SequenceClassifierOutput


class TargetSentimentClassifier(nn.Module):
    def __init__(self, model_name, num_labels=3, dropout_rate=0.3,
                 num_dropouts=5, use_multi_sample_dropout=False, device='cpu'):
        super().__init__()
        self.device = device
        self.tokenizer = RobertaTokenizer.from_pretrained(model_name)
        self.roberta = RobertaModel.from_pretrained(model_name)
        self.use_multi_sample_dropout = use_multi_sample_dropout
        self.num_labels = num_labels
        if self.use_multi_sample_dropout:
            # Multi-sample dropout: apply several independent dropout masks
            # and average the resulting logits for a more stable estimate.
            self.dropouts = nn.ModuleList(
                [nn.Dropout(dropout_rate) for _ in range(num_dropouts)]
            )
        else:
            self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(self.roberta.config.hidden_size, num_labels)
        self.to(device)

    def forward(self, input_ids, attention_mask=None, labels=None):
        outputs = self.roberta(input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state
        # Classify from the representation of the first ([CLS]-style) token.
        cls_output = sequence_output[:, 0, :]
        if self.use_multi_sample_dropout:
            logits_list = []
            for dropout in self.dropouts:
                dropped = dropout(cls_output)
                logits_list.append(self.classifier(dropped))
            # Average the logits produced under each dropout mask.
            avg_logits = torch.mean(torch.stack(logits_list), dim=0)
        else:
            dropped = self.dropout(cls_output)
            avg_logits = self.classifier(dropped)
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(avg_logits.view(-1, self.num_labels), labels.view(-1))
        return SequenceClassifierOutput(loss=loss, logits=avg_logits)

    def prepare_input(self, sentence, target, entity_type):
        # Wrap the target mention in entity markers so the model knows which
        # span to classify; the <en> ... </en> and <|TYPE|> tags are assumed
        # to match the markup used during fine-tuning.
        return sentence.replace(target, f"<en> {target} <|{entity_type}|> </en>")

    def predict(self, sentence, target, entity_type):
        self.eval()
        input_text = self.prepare_input(sentence, target, entity_type)
        inputs = self.tokenizer(
            input_text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():
            outputs = self.forward(**inputs)
        logits = outputs.logits
        # Index of the highest-scoring class (an int in [0, num_labels)).
        pred = torch.argmax(logits, dim=1).item()
        return pred
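

# --- Example usage ---------------------------------------------------------
# A minimal sketch. The base checkpoint name below is an assumption (any
# RoBERTa-compatible Russian checkpoint such as "sberbank-ai/ruRoberta-large"
# could serve); substitute the checkpoint and fine-tuned weights actually
# released with this repo.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = TargetSentimentClassifier(
        model_name="sberbank-ai/ruRoberta-large",  # assumed base checkpoint
        use_multi_sample_dropout=False,
        device=device,
    )
    # Loading the fine-tuned weights (the path is a placeholder):
    # state_dict = torch.load("pytorch_model.bin", map_location=device)
    # model.load_state_dict(state_dict)

    # Predict sentiment toward an ORG target mentioned in the sentence.
    pred = model.predict(
        sentence="Газпром отчитался о рекордной прибыли.",  # "Gazprom reported record profits."
        target="Газпром",
        entity_type="ORG",
    )
    print(pred)  # class index in {0, 1, 2}; the label mapping depends on training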