File size: 2,511 Bytes
ce7cacf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from .configuration_clip_camembert import CLIPTextCamembertConfig
from transformers import (
    CamembertModel,
    CLIPTextModelWithProjection,
)
from transformers.models.clip.modeling_clip import CLIPTextModelOutput
import torch
from torch import nn
from typing import Any, Optional, Tuple, Union


class CLIPTextCamembertModelWithProjection(CLIPTextModelWithProjection):
    """CLIP-style text tower that swaps the CLIP text encoder for CamemBERT.

    The CamemBERT pooler output is projected into the CLIP embedding space
    through a bias-free linear layer, so the model can be used as a drop-in
    text encoder producing CLIP-compatible text embeddings.
    """

    config_class = CLIPTextCamembertConfig

    def __init__(self, config: CLIPTextCamembertConfig):
        super().__init__(config)

        # Replace the CLIP text transformer with a CamemBERT encoder.
        self.text_model = CamembertModel(config)

        # Project the pooled hidden state into the CLIP embedding space.
        self.text_projection = nn.Linear(
            config.hidden_size, config.projection_dim, bias=False
        )
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPTextModelOutput]:
        """Encode text and return projected CLIP text embeddings.

        Returns a `CLIPTextModelOutput` when `return_dict` is true (or the
        config default says so), otherwise a tuple of
        `(text_embeds, last_hidden_state, *extras)` with `None` entries
        filtered out.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 is the pooler output of the CamemBERT encoder.
        pooled = encoder_outputs[1]
        projected = self.text_projection(pooled)

        if not return_dict:
            candidates = (projected, encoder_outputs[0]) + encoder_outputs[2:]
            return tuple(item for item in candidates if item is not None)

        return CLIPTextModelOutput(
            text_embeds=projected,
            last_hidden_state=encoder_outputs.last_hidden_state,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def converter_weight(
        self, path_model="airesearch/wangchanberta-base-att-spm-uncased"
    ):
        r"""
        converter weight from airesearch/wangchanberta-base-att-spm-uncased
        """
        # Load the donor CamemBERT checkpoint and copy its weights into
        # this model's text encoder (projection layer is left untouched).
        donor = CamembertModel.from_pretrained(path_model)
        self.text_model.load_state_dict(donor.state_dict())