from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers import CLIPTextConfig, CLIPTextModel
from transformers.modeling_outputs import MaskedLMOutput
from transformers.models.clip.modeling_clip import CLIPPreTrainedModel
from transformers.models.roberta.modeling_roberta import RobertaLMHead


class CLIPTextModelForMaskedLM(CLIPPreTrainedModel):
    """CLIP text encoder with a masked-language-modeling head.

    Wraps `CLIPTextModel` and projects its token-level hidden states to
    vocabulary logits with a `RobertaLMHead`, so the text tower can be
    trained or fine-tuned with an MLM objective.
    """

    config_class = CLIPTextConfig

    def __init__(self, config: CLIPTextConfig):
        super().__init__(config)

        self.clip_text_model = CLIPTextModel(config)
        self.lm_head = RobertaLMHead(config)

        # Initialize weights and apply any final processing.
        self.post_init()

    def get_input_embeddings(self):
        return self.clip_text_model.text_model.embeddings.token_embedding

    def set_input_embeddings(self, value):
        self.clip_text_model.text_model.embeddings.token_embedding = value

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.clip_text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Project token-level hidden states to vocabulary logits.
        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)

        mlm_loss = None
        if labels is not None:
            # Standard MLM loss; positions labelled -100 are ignored by default.
            loss_fct = nn.CrossEntropyLoss()
            mlm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
            )

        if not return_dict:
            # Tuple output: skip the pooled output at index 1.
            output = (prediction_scores,) + outputs[2:]
            return ((mlm_loss,) + output) if mlm_loss is not None else output

        return MaskedLMOutput(
            loss=mlm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
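

# Minimal usage sketch (an assumption, not part of the original file): runs a
# randomly initialized model on random token ids purely to illustrate the
# expected input/output shapes. Tokenization and real mask selection are
# omitted; in practice only masked positions would carry labels.
if __name__ == "__main__":
    config = CLIPTextConfig()  # default CLIP text config (vocab_size=49408)
    model = CLIPTextModelForMaskedLM(config)
    model.eval()

    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    labels = input_ids.clone()

    with torch.no_grad():
        out = model(input_ids=input_ids, labels=labels, return_dict=True)

    print(out.loss)          # scalar MLM loss
    print(out.logits.shape)  # (1, 8, config.vocab_size)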