import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers import BertTokenizerFast as BertTokenizer
from transformers.activations import ACT2FN

from .configuration_hformer import HformerConfig
from .configuration_projector import ProjectorConfig
from .modeling_projector import ProjectorModel
from .qformer_src import BertConfig, BertLMHeadModel

# Seed the global RNG at import time so the query-token initialization below
# is reproducible.
torch.manual_seed(1024)


class LayerNorm(nn.LayerNorm):
    """Thin wrapper around nn.LayerNorm; currently identical to the parent
    implementation and kept only as a customization hook."""

    def forward(self, x: torch.Tensor):
        return super().forward(x)

class HformerModel(PreTrainedModel):
    """Hybrid projector that fuses Q-Former query features with an MLP
    projection of the visual features before they are fed to the LLM."""

    _auto_class = 'AutoModel'
    config_class = HformerConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = False

    def __init__(self, config: HformerConfig) -> None:
        super().__init__(config)
        self.gradient_checkpointing = False

        # Unpack the relevant configuration fields.
        vision_width = config.visual_hidden_size
        num_query_token = config.num_query_token
        bert = config.bert
        llm_hidden_size = config.llm_hidden_size
        cross_attention_freq = config.cross_attention_freq
        qformer_pth = config.qformer_pth
        # Build a BERT-based Q-Former: a BERT encoder with cross-attention
        # blocks (inserted every `cross_attention_freq` layers) that attend to
        # the visual features.
        encoder_config = BertConfig.from_pretrained(bert)
        encoder_config.encoder_width = vision_width
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        encoder_config.num_hidden_layers = 12
        Qformer = BertLMHeadModel.from_pretrained(bert, config=encoder_config)

        # Optionally strip the text-specific parts of the Q-Former
        # (LM head, word/position embeddings, feed-forward sub-layers);
        # disabled here.
        remove_text = False
        if remove_text:
            Qformer.cls = None
            Qformer.bert.embeddings.word_embeddings = None
            Qformer.bert.embeddings.position_embeddings = None
            for layer in Qformer.bert.encoder.layer:
                layer.output = None
                layer.intermediate = None

        # Learnable query tokens that the Q-Former uses to attend over the
        # visual features.
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        
        self.Qformer = Qformer
        self.query_tokens = query_tokens
        # Projects Q-Former outputs into the LLM embedding space.
        self.llm_proj = nn.Linear(encoder_config.hidden_size, llm_hidden_size, bias=config.bias)
        self.ln_vision = LayerNorm(encoder_config.encoder_width)
        # Note: ln_llava is defined but not used in forward().
        self.ln_llava = LayerNorm(encoder_config.encoder_width)

        # Register a [DEC] BOS token and resize the token embeddings to match
        # the extended vocabulary.
        tokenizer = BertTokenizer.from_pretrained(bert, truncation_side='right')
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        self.Qformer.resize_token_embeddings(len(tokenizer))

        # Optionally initialize from a pretrained Q-Former checkpoint.
        if qformer_pth is not None:
            pretrained_state_dict = torch.load(qformer_pth, map_location='cpu')['model']
            print(f'Loading Q-Former weights from {qformer_pth}')
            self.load_state_dict(pretrained_state_dict, strict=False)
            print('Done.')

        # MLP connector that projects the raw visual features to the LLM
        # hidden size.
        projector_config = ProjectorConfig(
            visual_hidden_size=config.visual_hidden_size,
            llm_hidden_size=config.llm_hidden_size,
            projector_depth=2)
        self.connector = ProjectorModel(projector_config)

        # Bottleneck feed-forward block, applied residually to the fused features.
        modules = [
            nn.Linear(config.llm_hidden_size, config.llm_hidden_size // 4, bias=False),
            ACT2FN['gelu'],
            nn.Linear(config.llm_hidden_size // 4, config.llm_hidden_size, bias=False),
        ]
        self.ffn = nn.Sequential(*modules)

    def enable_input_require_grads(self):
        # Force the outputs of the sub-modules to require gradients so that
        # gradients can flow through them even when upstream modules are frozen.
        def make_inputs_require_grad(module, input, output):
            if isinstance(output, tuple):
                output[0].requires_grad_(True)
                output[1].requires_grad_(True)
            else:
                output.requires_grad_(True)

        self.Qformer.register_forward_hook(make_inputs_require_grad)
        self.llm_proj.register_forward_hook(make_inputs_require_grad)
        self.ln_vision.register_forward_hook(make_inputs_require_grad)
        self.connector.register_forward_hook(make_inputs_require_grad)
        self.ffn.register_forward_hook(make_inputs_require_grad)

    def _set_gradient_checkpointing(self, module, value=False):
        # No-op: gradient checkpointing is not supported by this module.
        pass

    def forward(self, x_):
        """Fuse Q-Former query features with MLP-projected visual features.

        Args:
            x_: visual features of shape (batch, num_tokens, visual_hidden_size).
        """
        if self.gradient_checkpointing and self.training:
            print('Gradient checkpointing is not supported; running without it.')

        # Q-Former branch: learnable query tokens cross-attend to the
        # layer-normalized visual features.
        x = self.ln_vision(x_)
        query_tokens = self.query_tokens.expand(x.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=x,
            return_dict=True,
        )
        # Project the query features into the LLM embedding space.
        q_feat = self.llm_proj(query_output.last_hidden_state)

        # MLP branch: project the raw visual features directly.
        mlp_feat = self.connector(x_)

        # Fuse: add the mean-pooled query feature to every visual token, then
        # apply the bottleneck FFN with a residual connection.
        int_feat = mlp_feat + q_feat.mean(dim=1)[:, None]
        out = int_feat + self.ffn(int_feat)

        return out
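

# Minimal usage sketch: it assumes HformerConfig accepts the keyword arguments
# read in __init__ above (visual_hidden_size, num_query_token, bert,
# llm_hidden_size, cross_attention_freq, qformer_pth, bias); the concrete
# values below are illustrative placeholders, not the settings of any
# released checkpoint.
if __name__ == '__main__':
    config = HformerConfig(
        visual_hidden_size=1024,       # placeholder vision-encoder width
        num_query_token=32,            # placeholder number of query tokens
        bert='bert-base-uncased',      # placeholder BERT backbone name
        llm_hidden_size=4096,          # placeholder LLM hidden size
        cross_attention_freq=2,        # placeholder cross-attention frequency
        qformer_pth=None,              # skip loading a pretrained Q-Former
        bias=True,
    )
    model = HformerModel(config).eval()

    # Dummy visual features: (batch, num_visual_tokens, visual_hidden_size).
    visual_feats = torch.randn(2, 256, config.visual_hidden_size)
    with torch.no_grad():
        out = model(visual_feats)
    # If the projector preserves the token count, `out` should have shape
    # (2, 256, llm_hidden_size).
    print(out.shape)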