from .configuration_clip_camembert import CLIPTextCamembertConfig
from transformers import (
    CamembertModel,
    CLIPTextModelWithProjection,
)
from transformers.models.clip.modeling_clip import CLIPTextModelOutput
import torch
from torch import nn
from typing import Optional, Tuple, Union


class CLIPTextCamembertModelWithProjection(CLIPTextModelWithProjection):
    r"""CLIP text model with a projection head that uses a CamembertModel (e.g. WangchanBERTa) as the text encoder."""

    config_class = CLIPTextCamembertConfig

    def __init__(self, config: CLIPTextCamembertConfig):
        super().__init__(config)
        # Swap the default CLIP text encoder for a Camembert encoder.
        self.text_model = CamembertModel(config)
        self.text_projection = nn.Linear(
            config.hidden_size, config.projection_dim, bias=False
        )
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPTextModelOutput]:
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Camembert's pooler output, projected into the CLIP embedding space.
        pooled_output = text_outputs[1]
        text_embeds = self.text_projection(pooled_output)

        if not return_dict:
            outputs = (text_embeds, text_outputs[0]) + text_outputs[2:]
            return tuple(output for output in outputs if output is not None)

        return CLIPTextModelOutput(
            text_embeds=text_embeds,
            last_hidden_state=text_outputs.last_hidden_state,
            hidden_states=text_outputs.hidden_states,
            attentions=text_outputs.attentions,
        )

    def converter_weight(
        self, path_model="airesearch/wangchanberta-base-att-spm-uncased"
    ):
        r"""
        Copy pretrained encoder weights from `airesearch/wangchanberta-base-att-spm-uncased`
        (or another Camembert-compatible checkpoint passed via `path_model`) into `self.text_model`.
        """
        pretrained_state_dict = CamembertModel.from_pretrained(path_model).state_dict()
        # Load the pretrained state dictionary into the custom text encoder
        self.text_model.load_state_dict(pretrained_state_dict)
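

if __name__ == "__main__":
    # Minimal usage sketch, not part of the modeling code. It assumes that
    # CLIPTextCamembertConfig can be built from the WangchanBERTa encoder config
    # plus a `projection_dim` override; check configuration_clip_camembert.py for
    # the actual constructor arguments. The checkpoint name comes from
    # converter_weight's default; projection_dim=512 is an illustrative value.
    from transformers import AutoTokenizer

    checkpoint = "airesearch/wangchanberta-base-att-spm-uncased"
    config = CLIPTextCamembertConfig.from_pretrained(checkpoint, projection_dim=512)

    model = CLIPTextCamembertModelWithProjection(config)
    model.converter_weight(checkpoint)  # copy pretrained encoder weights
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    inputs = tokenizer(["ตัวอย่างข้อความ"], return_tensors="pt", padding=True)

    with torch.no_grad():
        outputs = model(**inputs)
    print(outputs.text_embeds.shape)  # (batch_size, projection_dim)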