from typing import Optional

import torch
import torch.nn as nn
from transformers import RobertaModel, RobertaPreTrainedModel


class MyRobertaForSequenceClassification(RobertaPreTrainedModel):
    """RoBERTa encoder with a linear classification head over the <s> (CLS) token."""

    def __init__(self, config):
        super().__init__(config)
        # Number of target classes; fall back to 4 if the config does not set it.
        self.num_labels = getattr(config, "num_labels", 4)
        # Encoder backbone without the built-in pooler; pooling is handled in _pool().
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)
        self.loss_fn = nn.CrossEntropyLoss()
        # Initialize weights and apply final processing (Hugging Face convention).
        self.post_init()

    def _get_pad_id(self, input_ids):
        # Prefer the pad token id from the model config; RoBERTa's default is 1.
        pad_id = getattr(self.config, "pad_token_id", None)
        return pad_id if pad_id is not None else 1

    def _pool(self, last_hidden_state, attention_mask, model_type: str = "roberta"):
        if last_hidden_state.dim() == 3:
            # Encoders with a [CLS]/<s> sentence token: use its final hidden state.
            if model_type in {"bert", "roberta", "deberta", "xlm-roberta", "electra"}:
                return last_hidden_state[:, 0, :]
            # Otherwise fall back to masked mean pooling over the sequence length.
            mask = attention_mask.unsqueeze(-1).float()
            summed = (last_hidden_state * mask).sum(dim=1)
            denom = mask.sum(dim=1).clamp(min=1e-9)
            return summed / denom
        elif last_hidden_state.dim() == 2:
            # Already pooled to (batch, hidden); pass through unchanged.
            return last_hidden_state
        else:
            raise ValueError(f"Unexpected hidden dim: {last_hidden_state.dim()}")

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs
    ):
        # Build the attention mask from padding positions if the caller omitted it.
        if attention_mask is None and input_ids is not None:
            pad_id = self._get_pad_id(input_ids)
            attention_mask = input_ids.ne(pad_id)

        outputs = self.roberta(
            input_ids=input_ids, attention_mask=attention_mask, return_dict=True
        )
        hidden = outputs.last_hidden_state
        pooled = self._pool(hidden, attention_mask, "roberta")
        logits = self.classifier(pooled)

        # Compute the cross-entropy loss only when labels are provided.
        loss = None
        if labels is not None:
            loss = self.loss_fn(logits, labels.long())

        return {"loss": loss, "logits": logits}
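

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the model itself). Assumptions: the
# "roberta-base" checkpoint name and the 4-way label count are illustrative
# placeholders; substitute the checkpoint and num_labels your task uses.
# The classification head is freshly initialized, so the logits are only
# meaningful after fine-tuning.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("roberta-base")
    model = MyRobertaForSequenceClassification.from_pretrained(
        "roberta-base", num_labels=4
    )
    model.eval()

    batch = tokenizer(
        ["an example sentence", "another example"],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        out = model(**batch)
    print(out["logits"].shape)  # expected: torch.Size([2, 4])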