import torch
import torch.nn as nn
from transformers import PreTrainedModel, T5Config, T5EncoderModel
class T5RegressionModel(PreTrainedModel):
    config_class = T5Config

    def __init__(self, config, d_model=None):
        super().__init__(config)
        # Load the pretrained ProtT5 encoder; the regression head below is
        # trained from scratch.
        self.encoder = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
        # Size the head from the loaded encoder's config so the two always
        # match, even if the `config` passed in disagrees with the checkpoint.
        hidden_dim = d_model if d_model is not None else self.encoder.config.d_model
        self.regression_head = nn.Linear(hidden_dim, 1)
    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = encoder_outputs.last_hidden_state
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        # Mean-pool over real tokens only. Indexing hidden_states[:, -1, :]
        # would read a padding embedding for every sequence shorter than the
        # longest one in a right-padded batch.
        mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        pooled_output = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        logits = self.regression_head(pooled_output).squeeze(-1)
        loss = None
        if labels is not None:
            # MSE against float targets for scalar regression.
            loss = nn.MSELoss()(logits, labels.float())
        return {"loss": loss, "logits": logits}
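

# A minimal usage sketch, not part of the original snippet. Assumptions: the
# checkpoint's tokenizer is T5Tokenizer and sequences are passed as
# space-separated residues (per the Rostlab model card); the sequences and
# labels below are made up for illustration. Note that constructing the model
# downloads the full prot_t5_xl_uniref50 checkpoint, which is several GB.
if __name__ == "__main__":
    from transformers import T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50")
    config = T5Config.from_pretrained("Rostlab/prot_t5_xl_uniref50")
    model = T5RegressionModel(config)
    model.eval()

    # ProtT5 expects residues separated by spaces, e.g. "M K T ...".
    sequences = ["M K T A Y I A K Q R", "G S H M L E"]
    batch = tokenizer(sequences, padding=True, return_tensors="pt")

    with torch.no_grad():
        outputs = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=torch.tensor([0.5, 1.2]),  # illustrative targets
        )
    print(outputs["loss"].item(), outputs["logits"].shape)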