from transformers import T5EncoderModel, T5Config, PreTrainedModel
import torch
import torch.nn as nn


class T5RegressionModel(PreTrainedModel):
    """T5 encoder (ProtT5-XL UniRef50) with a scalar regression head.

    Encodes the input sequence, pools a single hidden state per example,
    and projects it to one scalar. When ``labels`` are given, an MSE loss
    is computed against the pooled predictions.
    """

    config_class = T5Config

    def __init__(self, config, d_model=None):
        """Build the encoder and regression head.

        Args:
            config: T5Config passed through to ``PreTrainedModel``.
            d_model: Optional override for the regression head's input
                width; defaults to ``config.d_model``.
        """
        super().__init__(config)
        # NOTE(review): hard-coding a pretrained checkpoint here ignores
        # `config` for the encoder weights and triggers a hub download on
        # every construction (including `from_pretrained` reloads of this
        # wrapper). Consider `T5EncoderModel(config)` + external weight
        # loading — left unchanged to preserve caller-visible behavior.
        self.encoder = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
        hidden_dim = d_model if d_model is not None else config.d_model
        self.regression_head = nn.Linear(hidden_dim, 1)

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Run the encoder, pool, and regress.

        Args:
            input_ids: Token ids, shape (batch, seq).
            attention_mask: 1 for real tokens, 0 for padding; used both by
                the encoder and to locate the last real token per sequence.
            labels: Optional regression targets, shape (batch,); cast to
                float for the MSE loss.
            **kwargs: Ignored; absorbs extra Trainer-supplied arguments.

        Returns:
            dict with "loss" (None when no labels) and "logits" (batch,).
        """
        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = encoder_outputs.last_hidden_state  # (batch, seq, d_model)

        if attention_mask is not None:
            # BUG FIX: the original pooled hidden_states[:, -1, :], which in
            # a right-padded batch is a <pad> embedding for every sequence
            # shorter than the batch max. Pool the last REAL token instead
            # (identical result for unpadded input).
            lengths = attention_mask.sum(dim=1).long().clamp(min=1)  # real tokens per row
            gather_idx = (lengths - 1).view(-1, 1, 1).expand(-1, 1, hidden_states.size(-1))
            pooled_output = hidden_states.gather(1, gather_idx).squeeze(1)
        else:
            # No mask: keep the original last-position behavior.
            pooled_output = hidden_states[:, -1, :]

        logits = self.regression_head(pooled_output).squeeze(-1)

        loss = None
        if labels is not None:
            loss = nn.MSELoss()(logits, labels.float())
        return {"loss": loss, "logits": logits}