dle666 commited on
Commit
0d06432
·
verified ·
1 Parent(s): 2029687

Upload modeling_rcot2b_chat.py

Browse files
Files changed (1) hide show
  1. modeling_rcot2b_chat.py +392 -0
modeling_rcot2b_chat.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+ import warnings
7
+ from typing import Any, List, Optional, Tuple, Union
8
+
9
+ import torch.utils.checkpoint
10
+ import transformers
11
+ from torch import nn
12
+ from torch.nn import CrossEntropyLoss
13
+ from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
14
+ LlamaTokenizer)
15
+ from transformers.modeling_outputs import CausalLMOutputWithPast
16
+ from transformers.modeling_utils import PreTrainedModel
17
+ from transformers.utils import ModelOutput, logging
18
+
19
+ from .configuration_internvl_chat import InternVLChatConfig
20
+ from .conversation import get_conv_template
21
+ from .modeling_intern_vit import InternVisionModel
22
+ from .modeling_internlm2 import InternLM2ForCausalLM
23
+ import pdb
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ def version_cmp(v1, v2, op='eq'):
29
+ import operator
30
+
31
+ from packaging import version
32
+ op_func = getattr(operator, op)
33
+ return op_func(version.parse(v1), version.parse(v2))
34
+
35
+
36
+ class RCoTChatModel2B(PreTrainedModel):
37
+ config_class = InternVLChatConfig
38
+ main_input_name = 'pixel_values'
39
+ _supports_flash_attn_2 = True
40
+ _no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'InternLM2DecoderLayer']
41
+
42
+ def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None):
43
+ super().__init__(config)
44
+
45
+ assert version_cmp(transformers.__version__, '4.36.2', 'ge')
46
+ image_size = config.force_image_size or config.vision_config.image_size
47
+ patch_size = config.vision_config.patch_size
48
+ self.patch_size = patch_size
49
+ self.select_layer = config.select_layer
50
+ self.template = config.template
51
+ self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
52
+ self.downsample_ratio = config.downsample_ratio
53
+ self.ps_version = config.ps_version
54
+
55
+ logger.info(f'num_image_token: {self.num_image_token}')
56
+ logger.info(f'ps_version: {self.ps_version}')
57
+ if vision_model is not None:
58
+ self.vision_model = vision_model
59
+ else:
60
+ self.vision_model = InternVisionModel(config.vision_config)
61
+ if language_model is not None:
62
+ self.language_model = language_model
63
+ else:
64
+ if config.llm_config.architectures[0] == 'LlamaForCausalLM':
65
+ self.language_model = LlamaForCausalLM(config.llm_config)
66
+ elif config.llm_config.architectures[0] == 'InternLM2ForCausalLM':
67
+ self.language_model = InternLM2ForCausalLM(config.llm_config)
68
+ else:
69
+ raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')
70
+
71
+ vit_hidden_size = config.vision_config.hidden_size
72
+ llm_hidden_size = config.llm_config.hidden_size
73
+
74
+ self.mlp1 = nn.Sequential(
75
+ nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
76
+ nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
77
+ nn.GELU(),
78
+ nn.Linear(llm_hidden_size, llm_hidden_size)
79
+ )
80
+
81
+ self.img_context_token_id = None
82
+ self.conv_template = get_conv_template(self.template)
83
+ self.system_message = self.conv_template.system_message
84
+
85
+ def forward(
86
+ self,
87
+ pixel_values: torch.FloatTensor,
88
+ input_ids: torch.LongTensor = None,
89
+ attention_mask: Optional[torch.Tensor] = None,
90
+ position_ids: Optional[torch.LongTensor] = None,
91
+ image_flags: Optional[torch.LongTensor] = None,
92
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
93
+ labels: Optional[torch.LongTensor] = None,
94
+ use_cache: Optional[bool] = None,
95
+ output_attentions: Optional[bool] = None,
96
+ output_hidden_states: Optional[bool] = None,
97
+ return_dict: Optional[bool] = None,
98
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
99
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
100
+
101
+ image_flags = image_flags.squeeze(-1)
102
+ input_embeds = self.language_model.get_input_embeddings()(input_ids)
103
+
104
+ vit_embeds = self.extract_feature(pixel_values)
105
+ vit_embeds = vit_embeds[image_flags == 1]
106
+ vit_batch_size = pixel_values.shape[0]
107
+
108
+ B, N, C = input_embeds.shape
109
+ input_embeds = input_embeds.reshape(B * N, C)
110
+
111
+ if torch.distributed.get_rank() == 0:
112
+ print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')
113
+
114
+ input_ids = input_ids.reshape(B * N)
115
+ selected = (input_ids == self.img_context_token_id)
116
+ try:
117
+ input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
118
+ except Exception as e:
119
+ vit_embeds = vit_embeds.reshape(-1, C)
120
+ print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
121
+ f'vit_embeds.shape={vit_embeds.shape}')
122
+ n_token = selected.sum()
123
+ input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
124
+
125
+ input_embeds = input_embeds.reshape(B, N, C)
126
+
127
+ outputs = self.language_model(
128
+ inputs_embeds=input_embeds,
129
+ attention_mask=attention_mask,
130
+ position_ids=position_ids,
131
+ past_key_values=past_key_values,
132
+ use_cache=use_cache,
133
+ output_attentions=output_attentions,
134
+ output_hidden_states=output_hidden_states,
135
+ return_dict=return_dict,
136
+ )
137
+ logits = outputs.logits
138
+
139
+ loss = None
140
+ if labels is not None:
141
+ # Shift so that tokens < n predict n
142
+ shift_logits = logits[..., :-1, :].contiguous()
143
+ shift_labels = labels[..., 1:].contiguous()
144
+ # Flatten the tokens
145
+ loss_fct = CrossEntropyLoss()
146
+ shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
147
+ shift_labels = shift_labels.view(-1)
148
+ # Enable model parallelism
149
+ shift_labels = shift_labels.to(shift_logits.device)
150
+ loss = loss_fct(shift_logits, shift_labels)
151
+
152
+ if not return_dict:
153
+ output = (logits,) + outputs[1:]
154
+ return (loss,) + output if loss is not None else output
155
+
156
+ return CausalLMOutputWithPast(
157
+ loss=loss,
158
+ logits=logits,
159
+ past_key_values=outputs.past_key_values,
160
+ hidden_states=outputs.hidden_states,
161
+ attentions=outputs.attentions,
162
+ )
163
+
164
+ def pixel_shuffle(self, x, scale_factor=0.5):
165
+ n, w, h, c = x.size()
166
+ # N, W, H, C --> N, W, H * scale, C // scale
167
+ x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
168
+ # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
169
+ x = x.permute(0, 2, 1, 3).contiguous()
170
+ # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
171
+ x = x.view(n, int(h * scale_factor), int(w * scale_factor),
172
+ int(c / (scale_factor * scale_factor)))
173
+ if self.ps_version == 'v1':
174
+ warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
175
+ 'which results in a transposed image.')
176
+ else:
177
+ x = x.permute(0, 2, 1, 3).contiguous()
178
+ return x
179
+
180
+ def extract_feature(self, pixel_values):
181
+ if self.select_layer == -1:
182
+ vit_embeds = self.vision_model(
183
+ pixel_values=pixel_values,
184
+ output_hidden_states=False,
185
+ return_dict=True).last_hidden_state
186
+ else:
187
+ vit_embeds = self.vision_model(
188
+ pixel_values=pixel_values,
189
+ output_hidden_states=True,
190
+ return_dict=True).hidden_states[self.select_layer]
191
+ vit_embeds = vit_embeds[:, 1:, :]
192
+
193
+ h = w = int(vit_embeds.shape[1] ** 0.5)
194
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
195
+ vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
196
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
197
+ vit_embeds = self.mlp1(vit_embeds)
198
+ return vit_embeds
199
+
200
+ def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
201
+ history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
202
+ IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
203
+ if history is not None or return_history:
204
+ print('Now multi-turn chat is not supported in batch_chat.')
205
+ raise NotImplementedError
206
+
207
+ if image_counts is not None:
208
+ num_patches_list = image_counts
209
+ print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')
210
+
211
+ img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
212
+ self.img_context_token_id = img_context_token_id
213
+
214
+ if verbose and pixel_values is not None:
215
+ image_bs = pixel_values.shape[0]
216
+ print(f'dynamic ViT batch size: {image_bs}')
217
+
218
+ queries = []
219
+ for idx, num_patches in enumerate(num_patches_list):
220
+ question = questions[idx]
221
+ if pixel_values is not None and '<image>' not in question:
222
+ question = '<image>\n' + question
223
+ template = get_conv_template(self.template)
224
+ template.append_message(template.roles[0], question)
225
+ template.append_message(template.roles[1], None)
226
+ query = template.get_prompt()
227
+
228
+ image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
229
+ query = query.replace('<image>', image_tokens, 1)
230
+ queries.append(query)
231
+
232
+ tokenizer.padding_side = 'left'
233
+ model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
234
+ input_ids = model_inputs['input_ids'].cuda()
235
+ attention_mask = model_inputs['attention_mask'].cuda()
236
+ eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
237
+ generation_config['eos_token_id'] = eos_token_id
238
+ generation_output = self.generate(
239
+ pixel_values=pixel_values,
240
+ input_ids=input_ids,
241
+ attention_mask=attention_mask,
242
+ **generation_config
243
+ )
244
+ responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
245
+ responses = [response.split(template.sep)[0].strip() for response in responses]
246
+ return responses
247
+
248
+ def chat(self, tokenizer, pixel_values, target_aspect_ratio, question, generation_config, use_scm=False, history=None, return_history=False,
249
+ num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
250
+ verbose=False):
251
+
252
+ if history is None and pixel_values is not None and '<image>' not in question:
253
+ question = '<image>\n' + question
254
+
255
+ if num_patches_list is None:
256
+ num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
257
+ assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
258
+
259
+ img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
260
+ self.img_context_token_id = img_context_token_id
261
+
262
+ template = get_conv_template(self.template)
263
+ template.system_message = self.system_message
264
+ eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
265
+
266
+ history = [] if history is None else history
267
+ for (old_question, old_answer) in history:
268
+ template.append_message(template.roles[0], old_question)
269
+ template.append_message(template.roles[1], old_answer)
270
+ template.append_message(template.roles[0], question)
271
+ template.append_message(template.roles[1], None)
272
+ query = template.get_prompt()
273
+
274
+ if verbose and pixel_values is not None:
275
+ image_bs = pixel_values.shape[0]
276
+ print(f'dynamic ViT batch size: {image_bs}')
277
+
278
+ for num_patches in num_patches_list:
279
+ image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
280
+ query = query.replace('<image>', image_tokens, 1)
281
+
282
+ model_inputs = tokenizer(query, return_tensors='pt')
283
+ input_ids = model_inputs['input_ids'].to(self.device)
284
+ attention_mask = model_inputs['attention_mask'].to(self.device)
285
+ generation_config['eos_token_id'] = eos_token_id
286
+ generation_output = self.generate(
287
+ pixel_values=pixel_values,
288
+ input_ids=input_ids,
289
+ attention_mask=attention_mask,
290
+ target_aspect_ratio=target_aspect_ratio,
291
+ use_scm=use_scm,
292
+ **generation_config
293
+ )
294
+ response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
295
+ response = response.split(template.sep)[0].strip()
296
+ history.append((question, response))
297
+ if return_history:
298
+ return response, history
299
+ else:
300
+ query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
301
+ query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
302
+ if verbose:
303
+ print(query_to_print, response)
304
+ return response
305
+
306
+ @torch.no_grad()
307
+ def generate(
308
+ self,
309
+ pixel_values: Optional[torch.FloatTensor] = None,
310
+ input_ids: Optional[torch.FloatTensor] = None,
311
+ attention_mask: Optional[torch.LongTensor] = None,
312
+ target_aspect_ratio: Optional[torch.LongTensor] = None,
313
+ visual_features: Optional[torch.FloatTensor] = None,
314
+ generation_config: Optional[GenerationConfig] = None,
315
+ output_hidden_states: Optional[bool] = None,
316
+ return_dict: Optional[bool] = None,
317
+ use_scm: Optional[bool] = False,
318
+ **generate_kwargs,
319
+ ) -> torch.LongTensor:
320
+
321
+ assert self.img_context_token_id is not None
322
+ if pixel_values is not None:
323
+ if visual_features is not None:
324
+ vit_embeds = visual_features
325
+ else:
326
+ vit_embeds = self.extract_feature(pixel_values)
327
+ input_embeds = self.language_model.get_input_embeddings()(input_ids)
328
+ B, N, C = input_embeds.shape
329
+ input_embeds = input_embeds.reshape(B * N, C)
330
+
331
+ input_ids = input_ids.reshape(B * N)
332
+ selected = (input_ids == self.img_context_token_id)
333
+ assert selected.sum() != 0
334
+ input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
335
+
336
+ input_embeds = input_embeds.reshape(B, N, C)
337
+ else:
338
+ input_embeds = self.language_model.get_input_embeddings()(input_ids)
339
+
340
+ if use_scm:
341
+ self.language_model.model.img_idx = torch.where(selected==True)
342
+ self.language_model.model.high_token = target_aspect_ratio[0]*target_aspect_ratio[1] * self.num_image_token
343
+ batch_size, seq_length = input_embeds.shape[:2]
344
+ device = input_embeds.device
345
+ position_ids = torch.arange(
346
+ 0, seq_length, dtype=torch.long, device=device
347
+ )
348
+ position_ids = position_ids.unsqueeze(0)
349
+ new_attention_mask = self.language_model.model._prepare_decoder_attention_mask(
350
+ attention_mask, (batch_size, seq_length), input_embeds, 0
351
+ )
352
+ tmp_layer_outputs = self.language_model.model.layers[0](
353
+ input_embeds,
354
+ attention_mask=new_attention_mask,
355
+ position_ids=position_ids,
356
+ past_key_value=None,
357
+ output_attentions=False,
358
+ use_cache=False,
359
+ )
360
+
361
+ tmp_layer_outputs2 = self.language_model.model.layers[1](
362
+ tmp_layer_outputs[0],
363
+ attention_mask=new_attention_mask,
364
+ position_ids=position_ids,
365
+ past_key_value=None,
366
+ output_attentions=True,
367
+ use_cache=False,
368
+ )
369
+
370
+ tmp_attn = tmp_layer_outputs2[1]
371
+ tmp_attn = tmp_attn[:,:,self.language_model.model.img_idx[0][0]+self.language_model.model.high_token:,self.language_model.model.img_idx[0][0]:self.language_model.model.img_idx[0][0]+self.language_model.model.high_token]
372
+ tmp_attn = tmp_attn.mean(2)
373
+ tmp_idx = tmp_attn.mean(1).topk(int(tmp_attn.shape[-1] * 0.5)).indices + self.language_model.model.img_idx[0][0]
374
+ top_attention_rank_index = tmp_idx.sort().values[0]
375
+ device = input_embeds.device
376
+ top_attention_rank_index = torch.cat((torch.arange(self.language_model.model.img_idx[0][0],device=device), top_attention_rank_index, torch.arange(self.language_model.model.img_idx[0][0]+self.language_model.model.high_token+1, input_embeds.shape[1],device=device)))
377
+ input_embeds = input_embeds[:,top_attention_rank_index]
378
+ attention_mask = torch.ones(
379
+ (input_embeds.shape[0], input_embeds.shape[1]), dtype=torch.bool, device=device
380
+ )
381
+
382
+ outputs = self.language_model.generate(
383
+ inputs_embeds=input_embeds,
384
+ attention_mask=attention_mask,
385
+ generation_config=generation_config,
386
+ output_hidden_states=output_hidden_states,
387
+ return_dict=return_dict,
388
+ use_cache=True,
389
+ **generate_kwargs,
390
+ )
391
+
392
+ return outputs