File size: 4,856 Bytes
26c2f02 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import torch
# Seeds torch's global RNG with a fixed value as a side effect of importing this module.
# NOTE(review): presumably meant to make the random initialisation of the layers
# defined below reproducible -- confirm that an import-time seed is intended.
torch.manual_seed(1024)
import torch.nn as nn
from transformers import PreTrainedModel
from .configuration_hformer import HformerConfig
from .qformer_src import BertConfig, BertLMHeadModel
from transformers import BertTokenizerFast as BertTokenizer
from .configuration_projector import ProjectorConfig
from .modeling_projector import ProjectorModel
import torch.nn.functional as F
from transformers.activations import ACT2FN
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` subclass whose forward pass defers entirely to the parent."""

    def forward(self, x: torch.Tensor):
        """Return the parent ``nn.LayerNorm`` result for ``x`` unchanged."""
        # No casting or post-processing is applied around the parent call.
        return super().forward(x)
class HformerModel(PreTrainedModel):
    """Combine a Q-Former branch and a projector branch over visual features.

    ``forward`` runs the input through two paths:

    * a layer-normalised copy is cross-attended by learned query tokens inside
      ``self.Qformer.bert`` and projected to ``llm_hidden_size`` by
      ``self.llm_proj``;
    * the raw input is passed through ``self.connector`` (a ``ProjectorModel``).

    The query features are averaged over dim 1, added to the projector
    output, and the sum goes through a residual bottleneck FFN.
    """

    _auto_class = 'AutoModel'
    config_class = HformerConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = False

    def __init__(self, config) -> None:
        """Build the Q-Former, the query tokens and the projection heads.

        Args:
            config: Model configuration. The fields read here are
                ``visual_hidden_size``, ``num_query_token``, ``bert``,
                ``llm_hidden_size``, ``cross_attention_freq``,
                ``qformer_pth`` and ``bias``.
        """
        super().__init__(config)
        self.gradient_checkpointing = False

        vision_width = config.visual_hidden_size
        num_query_token = config.num_query_token
        bert = config.bert
        llm_hidden_size = config.llm_hidden_size
        cross_attention_freq = config.cross_attention_freq
        qformer_pth = config.qformer_pth

        # NOTE: sub-modules are created and registered in a fixed order on
        # purpose. Reordering them would change both the random
        # initialisation sequence and the parameter / state-dict order.

        # Configure the BERT encoder so that it cross-attends to the visual
        # features (width ``vision_width``) using ``num_query_token`` queries.
        encoder_config = BertConfig.from_pretrained(bert)
        encoder_config.encoder_width = vision_width
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        # The depth is fixed here regardless of what the ``bert`` config says.
        encoder_config.num_hidden_layers = 12
        Qformer = BertLMHeadModel.from_pretrained(
            bert, config=encoder_config
        )

        # Learned query embeddings, drawn from N(0, initializer_range).
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)

        self.Qformer = Qformer
        self.query_tokens = query_tokens
        self.llm_proj = nn.Linear(encoder_config.hidden_size, llm_hidden_size, bias=config.bias)
        self.ln_vision = LayerNorm(encoder_config.encoder_width)
        # Not used by forward(); retained so the module's parameter names
        # stay unchanged.
        self.ln_llava = LayerNorm(encoder_config.encoder_width)

        # The tokenizer is only needed for its vocabulary size after adding
        # the "[DEC]" token, so that the embedding table can be resized.
        tokenizer = BertTokenizer.from_pretrained(bert, truncation_side='right')
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        self.Qformer.resize_token_embeddings(len(tokenizer))

        if qformer_pth is not None:
            # NOTE(review): torch.load unpickles the file, which can execute
            # arbitrary code. Only point ``qformer_pth`` at trusted checkpoints.
            pretrained_state_dict = torch.load(qformer_pth, map_location='cpu')['model']
            print(f'Load Qformer from {qformer_pth}')
            # strict=False: missing or unexpected keys are tolerated and the
            # returned key report is not inspected.
            self.load_state_dict(pretrained_state_dict, strict=False)
            print('Done.')

        projector_config = ProjectorConfig(
            visual_hidden_size=config.visual_hidden_size,
            llm_hidden_size=config.llm_hidden_size,
            projector_depth=2,
        )
        self.connector = ProjectorModel(projector_config)

        # Bias-free bottleneck FFN: llm_hidden_size -> llm_hidden_size // 4
        # -> llm_hidden_size, with a GELU in between.
        modules = [
            nn.Linear(config.llm_hidden_size, config.llm_hidden_size // 4, bias=False),
            ACT2FN['gelu'],
            nn.Linear(config.llm_hidden_size // 4, config.llm_hidden_size, bias=False),
        ]
        self.ffn = nn.Sequential(*modules)

    def enable_input_require_grads(self):
        """Register forward hooks that mark sub-module outputs as requiring grad.

        Hooks are attached to ``Qformer``, ``llm_proj``, ``ln_vision``,
        ``connector`` and ``ffn``. For tuple outputs only the first two
        entries are considered. Entries that are not tensors (for example
        ``None`` or nested structures) are skipped rather than raising.
        """

        def make_inputs_require_grad(module, input, output):
            # Slicing instead of indexing keeps one-element tuples from
            # raising IndexError.
            candidates = output[:2] if isinstance(output, tuple) else (output,)
            for candidate in candidates:
                # Only tensors have requires_grad_; ignore everything else.
                if isinstance(candidate, torch.Tensor):
                    candidate.requires_grad_(True)

        hooked_modules = (
            self.Qformer,
            self.llm_proj,
            self.ln_vision,
            self.connector,
            self.ffn,
        )
        for hooked in hooked_modules:
            hooked.register_forward_hook(make_inputs_require_grad)

    def _set_gradient_checkpointing(self, module, value=False):
        """Do nothing; this class sets ``supports_gradient_checkpointing = False``."""

    def forward(self, x_):
        """Fuse Q-Former query features with projector features.

        Args:
            x_: Visual features. The last dimension must match
                ``config.visual_hidden_size`` (it is layer-normalised with
                that width). Assumed to be (batch, seq, visual_hidden_size)
                -- TODO confirm against callers.

        Returns:
            ``int_feat + ffn(int_feat)`` where ``int_feat`` is the connector
            output plus the dim-1 mean of the projected query features.
        """
        if self.gradient_checkpointing and self.training:
            print('Not support gradient checkpointing')

        # Q-Former branch: queries cross-attend to the normalised features.
        x = self.ln_vision(x_)
        query_tokens = self.query_tokens.expand(x.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=x,
            return_dict=True,
        )
        q_feat = self.llm_proj(query_output.last_hidden_state)

        # Projector branch: operates on the raw (un-normalised) input.
        mlp_feat = self.connector(x_)

        # Average the query features over dim 1 and re-insert that axis with
        # [:, None] so the result broadcasts over dim 1 of mlp_feat.
        int_feat = mlp_feat + q_feat.mean(dim=1)[:, None]
        return int_feat + self.ffn(int_feat)
|