Enable intel Gaudi platform

#11
Files changed (1) hide show
  1. modeling_cogvlm.py +846 -862
modeling_cogvlm.py CHANGED
@@ -1,862 +1,846 @@
1
- """largely copy from llama and adapt for cogvlm"""
2
- import warnings
3
- from typing import TYPE_CHECKING, Optional, Tuple, List, Union, Literal, Dict, Any
4
-
5
- import math
6
- import torch
7
- from torch import nn
8
- from torch.nn import CrossEntropyLoss
9
- from torchvision import transforms
10
- from einops import rearrange
11
- import transformers
12
- from transformers import PreTrainedModel, PreTrainedTokenizer
13
- from transformers.utils.logging import get_logger
14
- from transformers.activations import ACT2FN
15
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
16
- from torchvision.transforms import Lambda
17
- from torchvision.transforms._transforms_video import NormalizeVideo, CenterCropVideo
18
- from pytorchvideo.transforms import ShortSideScale
19
- from .configuration_cogvlm import CogVLMConfig
20
- from .util import FastRotaryEmbedding
21
- from .visual import EVA2CLIPModel
22
-
23
- if TYPE_CHECKING:
24
- from transformers.utils import ModelOutput
25
-
26
- logger = get_logger(__name__)
27
-
28
- LANGUAGE_TOKEN_TYPE = 0
29
- VISION_TOKEN_TYPE = 1
30
-
31
- ALL_CACHE_NAMES = [
32
- "past_key_values", # default
33
- "cache_params", # mamba-based models
34
- "state", # rwkv
35
- "mems", # xlnet
36
- "past_buckets_states", # reformer
37
- ]
38
-
39
- # Copied from transformers.models.bart.modeling_bart._make_causal_mask
40
- def _make_causal_mask(
41
- input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
42
- ):
43
- """
44
- Make causal mask used for bi-directional self-attention.
45
- """
46
- bsz, tgt_len = input_ids_shape
47
- mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
48
- mask_cond = torch.arange(mask.size(-1), device=device)
49
- mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
50
- mask = mask.to(dtype)
51
-
52
- if past_key_values_length > 0:
53
- mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
54
- return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
55
-
56
-
57
- # Copied from transformers.models.bart.modeling_bart._expand_mask
58
- def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
59
- """
60
- Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
61
- """
62
- bsz, src_len = mask.size()
63
- tgt_len = tgt_len if tgt_len is not None else src_len
64
-
65
- expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
66
-
67
- inverted_mask = 1.0 - expanded_mask
68
-
69
- return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
70
-
71
-
72
- class RMSNorm(nn.Module):
73
- def __init__(self, hidden_size, eps=1e-5):
74
- super().__init__()
75
- self.weight = nn.Parameter(torch.ones(hidden_size))
76
- self.variance_epsilon = eps
77
-
78
- def forward(self, hidden_states):
79
- input_dtype = hidden_states.dtype
80
- hidden_states = hidden_states.to(torch.float32)
81
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
82
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
83
- return (self.weight * hidden_states).to(input_dtype)
84
-
85
-
86
- class MLP(nn.Module):
87
- def __init__(self, config):
88
- super().__init__()
89
- self.hidden_size = config.hidden_size
90
- self.intermediate_size = config.intermediate_size
91
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
92
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
93
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
94
- self.act_fn = ACT2FN[config.hidden_act]
95
-
96
- def forward(self, x):
97
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
98
- return down_proj
99
-
100
-
101
- def get_expert_mask(token_type_ids: "torch.LongTensor(B, L)") -> "[torch.BoolTensor(B, L), torch.BoolTensor(B, L)]":
102
- vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool)
103
- vision_token_mask[:, :-1] = (token_type_ids[:, :-1] == VISION_TOKEN_TYPE) & (
104
- token_type_ids[:, 1:] == VISION_TOKEN_TYPE)
105
- language_token_mask = ~vision_token_mask
106
- return vision_token_mask, language_token_mask
107
-
108
-
109
- class VisionExpertMLP(nn.Module):
110
- def __init__(self, config):
111
- super().__init__()
112
- self.language_mlp = MLP(config)
113
- # self.vision_mlp = MLP(config)
114
-
115
- def forward(self, hidden_states: "torch.Tensor(B, L, D)", token_type_ids: "torch.LongTensor(B, L)"):
116
- # output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device)
117
- # vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
118
- # output[vision_token_mask] = self.vision_mlp(hidden_states[vision_token_mask])
119
- # output[language_token_mask] = self.language_mlp(hidden_states[language_token_mask])
120
-
121
- output = self.language_mlp(hidden_states)
122
- return output
123
-
124
-
125
- def attention_fn(
126
- query_layer: "torch.tensor(B, H, L, HD)",
127
- key_layer: "torch.tensor(B, H, L, HD)",
128
- value_layer: "torch.tensor(B, H, L, HD)",
129
- attention_mask: "torch.tensor(B, H, L, HD)",
130
- *,
131
- scaling_attention_score: bool = True,
132
- attention_dropout: nn.Module = None
133
- ):
134
- attention_mask_bool = (attention_mask == 0)
135
- is_low_triangle = (attention_mask_bool == torch.ones_like(attention_mask_bool, dtype=torch.float).tril()).all()
136
- is_full = (attention_mask_bool > 0).all()
137
- if not (int(torch.__version__.split('.')[0]) >= 2):
138
- warnings.warn("It's recommended to use torch2.0 or higher.")
139
- if int(torch.__version__.split('.')[0]) >= 2 and scaling_attention_score and (is_full or is_low_triangle):
140
- dropout_p = 0. if attention_dropout is None or not attention_dropout.training else attention_dropout.p
141
- return torch.nn.functional.scaled_dot_product_attention(
142
- query_layer, key_layer, value_layer,
143
- attn_mask=None,
144
- dropout_p=dropout_p,
145
- is_causal=not is_full
146
- )
147
- else:
148
- if scaling_attention_score:
149
- query_layer = query_layer / math.sqrt(query_layer.shape[-1])
150
- attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
151
- attention_scores = attention_scores + attention_mask
152
- attention_scores = nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).to(query_layer.dtype)
153
- if attention_dropout is not None:
154
- attention_scores = attention_dropout(attention_scores)
155
- context_layer = torch.matmul(attention_scores, value_layer)
156
- return context_layer
157
-
158
-
159
- class VisionExpertAttention(nn.Module):
160
- def __init__(self, config):
161
- super().__init__()
162
- self.config = config
163
- self.hidden_size = config.hidden_size
164
- self.num_attention_heads = config.num_attention_heads
165
- self.num_multi_query_heads = config.num_multi_query_heads
166
- self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads
167
- self.stride = [self.num_attention_heads, self.num_multi_query_heads, self.num_multi_query_heads]
168
- self.qkv_size = self.hidden_size + self.hidden_size_per_attention_head * self.num_multi_query_heads * 2
169
- self.head_dim = self.hidden_size // self.num_attention_heads
170
- self.max_position_embeddings = config.max_position_embeddings
171
- self.rotary_emb = FastRotaryEmbedding(dim=self.head_dim, pos_idx_in_fp32=False, base=500000)
172
- # self.vision_expert_query_key_value = nn.Linear(self.hidden_size, self.qkv_size, bias=True)
173
- # self.vision_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
174
- self.language_expert_query_key_value = nn.Linear(self.hidden_size, self.qkv_size, bias=False)
175
- self.language_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
176
-
177
- def _transpose_for_scores(self, tensor):
178
- """Transpose a 3D tensor [B, L, H*HD] into a 4D tensor with size [B H L HD]."""
179
- new_tensor_shape = tensor.size()[:-1] + \
180
- (-1, # flexible for multi-query
181
- self.hidden_size_per_attention_head)
182
- tensor = tensor.view(*new_tensor_shape)
183
- return tensor.permute(0, 2, 1, 3)
184
-
185
- def forward(
186
- self,
187
- hidden_states: torch.Tensor,
188
- token_type_ids: torch.LongTensor,
189
- position_ids: torch.LongTensor,
190
- attention_mask: Optional[torch.Tensor] = None,
191
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
192
- output_attentions: bool = False,
193
- use_cache: bool = False,
194
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
195
- bsz, q_len, _ = hidden_states.size()
196
- # vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
197
-
198
- shape = list(hidden_states.shape)
199
- shape[-1] = self.qkv_size
200
- # mixed_raw_layer = torch.empty(shape, dtype=hidden_states.dtype, device=hidden_states.device)
201
- # mixed_raw_layer[vision_token_mask] = self.vision_expert_query_key_value(hidden_states[vision_token_mask])
202
- # mixed_raw_layer[language_token_mask] = self.language_expert_query_key_value(hidden_states[language_token_mask])
203
- mixed_raw_layer = self.language_expert_query_key_value(hidden_states)
204
-
205
- # query_states, key_states, value_states = torch.split(mixed_raw_layer, self.hidden_size, dim=-1)
206
- factor = mixed_raw_layer.size()[-1] // sum(self.stride)
207
- query_states, key_states, value_states = torch.split(mixed_raw_layer, [factor * x for x in self.stride], dim=-1)
208
-
209
- query_states = self._transpose_for_scores(query_states) # B, H, L, HD
210
- key_states = self._transpose_for_scores(key_states) # B, H, L, HD
211
- value_states = self._transpose_for_scores(value_states) # B, H, L, HD
212
-
213
- kv_seq_len = key_states.shape[-2]
214
- if past_key_value is not None:
215
- kv_seq_len += past_key_value[0].shape[-2]
216
-
217
- query_states, key_states = self.rotary_emb(query_states, key_states, position_ids=position_ids,
218
- max_seqlen=position_ids.max() + 1)
219
-
220
- if past_key_value is not None:
221
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
222
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
223
-
224
- past_key_value = (key_states, value_states) if use_cache else None
225
-
226
- key_states = key_states.unsqueeze(2).expand(-1, -1, self.num_attention_heads // self.num_multi_query_heads, -1,
227
- -1).contiguous().view(
228
- bsz, self.num_attention_heads, *key_states.shape[2:])
229
- value_states = value_states.unsqueeze(2).expand(-1, -1, self.num_attention_heads // self.num_multi_query_heads,
230
- -1,
231
- -1).contiguous().view(bsz, self.num_attention_heads,
232
- *value_states.shape[2:])
233
-
234
- context_layer = attention_fn(
235
- query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
236
- scaling_attention_score=True, attention_dropout=None)
237
- if context_layer.size() != (bsz, self.num_attention_heads, q_len, self.head_dim):
238
- raise ValueError(
239
- f"`attn_output` should be of size {(bsz, self.num_attention_heads, q_len, self.head_dim)}, but is"
240
- f" {context_layer.size()}"
241
- )
242
- context_layer = context_layer.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
243
-
244
- # attn_output = torch.empty(context_layer.shape, dtype=hidden_states.dtype, device=hidden_states.device)
245
- # attn_output[vision_token_mask] = self.vision_expert_dense(context_layer[vision_token_mask])
246
- # attn_output[language_token_mask] = self.language_expert_dense(context_layer[language_token_mask])
247
-
248
- attn_output = self.language_expert_dense(context_layer)
249
-
250
- if output_attentions:
251
- warnings.warn("output_attentions is not implemented.")
252
-
253
- return attn_output, None, past_key_value
254
-
255
-
256
- class CogVLMDecoderLayer(nn.Module):
257
- def __init__(self, config):
258
- super().__init__()
259
- self.hidden_size = config.hidden_size
260
- self.self_attn = VisionExpertAttention(config=config)
261
- self.mlp = VisionExpertMLP(config)
262
- self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
263
- self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
264
-
265
- def forward(
266
- self,
267
- hidden_states: torch.Tensor,
268
- token_type_ids: torch.LongTensor,
269
- position_ids: torch.LongTensor,
270
- attention_mask: Optional[torch.Tensor] = None,
271
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
272
- output_attentions: Optional[bool] = False,
273
- use_cache: Optional[bool] = False,
274
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
275
- residual = hidden_states
276
-
277
- hidden_states = self.input_layernorm(hidden_states)
278
-
279
- # Self Attention
280
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
281
- hidden_states=hidden_states,
282
- token_type_ids=token_type_ids,
283
- position_ids=position_ids,
284
- attention_mask=attention_mask,
285
- past_key_value=past_key_value,
286
- output_attentions=output_attentions,
287
- use_cache=use_cache,
288
- )
289
- hidden_states = residual + hidden_states
290
-
291
- # Fully Connected
292
- residual = hidden_states
293
- hidden_states = self.post_attention_layernorm(hidden_states)
294
- hidden_states = self.mlp(hidden_states, token_type_ids=token_type_ids)
295
- hidden_states = residual + hidden_states
296
-
297
- outputs = (hidden_states,)
298
-
299
- if output_attentions:
300
- outputs += (self_attn_weights,)
301
-
302
- if use_cache:
303
- outputs += (present_key_value,)
304
-
305
- return outputs # type: ignore
306
-
307
-
308
- class CogVLMPreTrainedModel(PreTrainedModel):
309
- config_class = CogVLMConfig
310
- base_model_prefix = "model"
311
- supports_gradient_checkpointing = False
312
- _no_split_modules = ["CogVLMDecoderLayer"]
313
- _skip_keys_device_placement = "past_key_values"
314
-
315
- def _init_weights(self, module):
316
- std = self.config.initializer_range
317
- if isinstance(module, nn.Linear):
318
- module.weight.data.normal_(mean=0.0, std=std)
319
- if module.bias is not None:
320
- module.bias.data.zero_()
321
- elif isinstance(module, nn.Embedding):
322
- module.weight.data.normal_(mean=0.0, std=std)
323
- if module.padding_idx is not None:
324
- module.weight.data[module.padding_idx].zero_()
325
-
326
-
327
- def is_empty(images_list: Optional[List[List[torch.Tensor]]]):
328
- if images_list is None or len(images_list) == 0:
329
- return True
330
- for image_list in images_list:
331
- if len(image_list):
332
- return False
333
- return True
334
-
335
-
336
- def build_position_ids(x: "torch.BoolTensor(B, L)",
337
- attention_mask: Optional["torch.BoolTensor(B, L)"] = None) -> "torch.LongTensor(B, L)":
338
- if attention_mask is not None:
339
- tmp = x.clone()
340
- tmp[~(attention_mask.bool())] = -1
341
- else:
342
- tmp = x.clone()
343
- # image boi eoi token as LANGUAGE_TOKEN_TYPE
344
- is_boi_eoi = torch.zeros_like(x, dtype=torch.bool)
345
- is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
346
- is_boi_eoi[:, 0] |= (tmp[:, 0] == VISION_TOKEN_TYPE)
347
- is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE)
348
- is_boi_eoi[:, -1] |= (tmp[:, -1] == VISION_TOKEN_TYPE)
349
- tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
350
- # final position ids
351
- y = torch.zeros_like(x, dtype=torch.long)
352
- y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | (
353
- (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE))
354
- y = y.cumsum(dim=-1)
355
- return y
356
-
357
-
358
- class CogVLMVideoModel(CogVLMPreTrainedModel):
359
- def __init__(self, config):
360
- super().__init__(config)
361
- self.padding_idx = 128002
362
- self.vocab_size = config.vocab_size
363
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
364
- self.layers = nn.ModuleList([CogVLMDecoderLayer(config) for _ in range(config.num_hidden_layers)])
365
- self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
366
-
367
- self.vision = EVA2CLIPModel(config)
368
-
369
- self.gradient_checkpointing = False
370
- # Initialize weights and apply final processing
371
- self.post_init()
372
-
373
- def encode_images(self, images: List[List[torch.Tensor]], ) -> torch.Tensor:
374
- images_list, images = images, []
375
-
376
- images = []
377
- for image_list in images_list:
378
- for image in image_list:
379
- images.append(image)
380
-
381
- # images = torch.stack(images) # video images is already stacked
382
- images_features = self.vision(images[0])
383
- return images_features
384
-
385
- def forward(
386
- self,
387
- input_ids: torch.LongTensor = None,
388
- images: List[List[torch.Tensor]] = None,
389
- token_type_ids: Optional[torch.LongTensor] = None,
390
- attention_mask: Optional[torch.Tensor] = None,
391
- position_ids: Optional[torch.LongTensor] = None,
392
- past_key_values: Optional[List[torch.FloatTensor]] = None,
393
- inputs_embeds: Optional[torch.FloatTensor] = None,
394
- use_cache: Optional[bool] = None,
395
- output_attentions: Optional[bool] = None,
396
- output_hidden_states: Optional[bool] = None,
397
- return_dict: Optional[bool] = None,
398
- ) -> Union[Tuple, BaseModelOutputWithPast]:
399
- """take care of image_encode, token_type_ids, position_ids and (attention_mask = None is fine)"""
400
-
401
- if past_key_values is not None:
402
- pass # generate mode with past_key_values. the image features are already mapped
403
- else:
404
- # not allow for inputs_embeds, because we want to process image feature
405
- assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
406
- if not is_empty(images): # multi-modality
407
- assert token_type_ids is not None, f"multi-modality requires `token_type_ids`!"
408
- assert len(input_ids) == len(images), f"{len(input_ids)} {len(images)}"
409
- inputs_embeds = self.embed_tokens(input_ids)
410
- images_features = self.encode_images(images)
411
- images_features = rearrange(images_features, 'b n d -> (b n) d')
412
- images_features = images_features.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
413
- inputs_embeds = inputs_embeds.index_put([token_type_ids == VISION_TOKEN_TYPE], images_features)
414
- else: # single-modality
415
- if token_type_ids is None:
416
- token_type_ids = torch.ones_like(input_ids, dtype=torch.long,
417
- device=input_ids.device) * LANGUAGE_TOKEN_TYPE
418
- assert not (token_type_ids == VISION_TOKEN_TYPE).any(), f"{(token_type_ids == VISION_TOKEN_TYPE).sum()}"
419
- inputs_embeds = self.embed_tokens(input_ids)
420
-
421
- if position_ids is None:
422
- position_ids = build_position_ids(token_type_ids, attention_mask)
423
- input_ids = None
424
- return self.llm_forward(
425
- input_ids=input_ids,
426
- token_type_ids=token_type_ids,
427
- attention_mask=attention_mask,
428
- position_ids=position_ids,
429
- past_key_values=past_key_values,
430
- inputs_embeds=inputs_embeds,
431
- use_cache=use_cache,
432
- output_attentions=output_attentions,
433
- output_hidden_states=output_hidden_states,
434
- return_dict=return_dict,
435
- )
436
-
437
- def llm_forward(
438
- self,
439
- input_ids: torch.LongTensor = None,
440
- token_type_ids: torch.LongTensor = None,
441
- attention_mask: Optional[torch.Tensor] = None,
442
- position_ids: Optional[torch.LongTensor] = None,
443
- past_key_values: Optional[List[torch.FloatTensor]] = None,
444
- inputs_embeds: Optional[torch.FloatTensor] = None,
445
- use_cache: Optional[bool] = None,
446
- output_attentions: Optional[bool] = None,
447
- output_hidden_states: Optional[bool] = None,
448
- return_dict: Optional[bool] = None,
449
- ) -> Union[Tuple, BaseModelOutputWithPast]:
450
- """largely copy from llama forward and adapt for cogvlm with `token_type_ids`"""
451
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
452
- output_hidden_states = (
453
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
454
- )
455
- use_cache = use_cache if use_cache is not None else self.config.use_cache
456
-
457
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
458
-
459
- # retrieve input_ids and inputs_embeds
460
- if input_ids is not None and inputs_embeds is not None:
461
- raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
462
- elif input_ids is not None:
463
- batch_size, seq_length = input_ids.shape
464
- elif inputs_embeds is not None:
465
- batch_size, seq_length, _ = inputs_embeds.shape
466
- else:
467
- raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
468
-
469
- seq_length_with_past = seq_length
470
- past_key_values_length = 0
471
-
472
- if past_key_values is not None:
473
- past_key_values_length = past_key_values[0][0].shape[2]
474
- seq_length_with_past = seq_length_with_past + past_key_values_length
475
-
476
- if position_ids is None:
477
- device = input_ids.device if input_ids is not None else inputs_embeds.device
478
- position_ids = torch.arange(
479
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
480
- )
481
- position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
482
- else:
483
- position_ids = position_ids.view(-1, seq_length).long()
484
-
485
- if inputs_embeds is None:
486
- inputs_embeds = self.embed_tokens(input_ids)
487
- # embed positions
488
- if attention_mask is None:
489
- attention_mask = torch.ones(
490
- (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
491
- )
492
- attention_mask = self._prepare_decoder_attention_mask(
493
- attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
494
- )
495
-
496
- hidden_states = inputs_embeds
497
-
498
- # decoder layers
499
- all_hidden_states = () if output_hidden_states else None
500
- all_self_attns = () if output_attentions else None
501
- next_decoder_cache = () if use_cache else None
502
-
503
- for idx, decoder_layer in enumerate(self.layers):
504
- if output_hidden_states:
505
- all_hidden_states += (hidden_states,)
506
-
507
- past_key_value = past_key_values[idx] if past_key_values is not None else None
508
- layer_outputs = decoder_layer(
509
- hidden_states,
510
- token_type_ids=token_type_ids,
511
- attention_mask=attention_mask,
512
- position_ids=position_ids,
513
- past_key_value=past_key_value,
514
- output_attentions=output_attentions,
515
- use_cache=use_cache,
516
- )
517
- hidden_states = layer_outputs[0]
518
-
519
- if use_cache:
520
- next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
521
-
522
- if output_attentions:
523
- all_self_attns += (layer_outputs[1],)
524
-
525
- hidden_states = self.norm(hidden_states)
526
-
527
- # add hidden states from the last decoder layer
528
- if output_hidden_states:
529
- all_hidden_states += (hidden_states,)
530
-
531
- next_cache = next_decoder_cache if use_cache else None
532
- if not return_dict:
533
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
534
- return BaseModelOutputWithPast(
535
- last_hidden_state=hidden_states,
536
- past_key_values=next_cache,
537
- hidden_states=all_hidden_states,
538
- attentions=all_self_attns,
539
- )
540
-
541
- def get_input_embeddings(self):
542
- return self.embed_tokens
543
-
544
- def set_input_embeddings(self, value):
545
- self.embed_tokens = value
546
-
547
- # noinspection PyMethodMayBeStatic
548
- # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
549
- def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
550
- # create causal mask
551
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
552
- combined_attention_mask = None
553
- if input_shape[-1] > 1:
554
- combined_attention_mask = _make_causal_mask(
555
- input_shape,
556
- inputs_embeds.dtype,
557
- device=inputs_embeds.device,
558
- past_key_values_length=past_key_values_length,
559
- )
560
-
561
- if attention_mask is not None:
562
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
563
- expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
564
- inputs_embeds.device
565
- )
566
- combined_attention_mask = (
567
- expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
568
- )
569
-
570
- return combined_attention_mask
571
-
572
-
573
- def _history_to_prompt(signal_type, history, query):
574
- if signal_type == 'base':
575
- return query
576
- elif signal_type == 'vqa':
577
- answer_format = 'Short answer:'
578
- elif signal_type == 'chat':
579
- answer_format = 'Answer:'
580
- else:
581
- assert False, f"Unknown signal type {signal_type}"
582
-
583
- prompt = ''
584
- for i, (old_query, response) in enumerate(history):
585
- prompt += 'Question: ' + old_query + " {} ".format(answer_format) + response + "\n"
586
- prompt += 'Question: {} {}'.format(query, answer_format)
587
- return prompt
588
-
589
-
590
- class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
591
- _auto_class = "AutoModelForCausalLM"
592
-
593
- def __init__(self, config):
594
- super().__init__(config)
595
- self.model = CogVLMVideoModel(config)
596
- self.vocab_size = config.vocab_size
597
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
598
- self.video_downsample = 1 # TODO: change this to config
599
-
600
- # Initialize weights and apply final processing
601
- self.post_init()
602
-
603
- def get_input_embeddings(self):
604
- return self.model.embed_tokens
605
-
606
- def set_input_embeddings(self, value):
607
- self.model.embed_tokens = value
608
-
609
- def get_output_embeddings(self):
610
- return self.lm_head
611
-
612
- def set_output_embeddings(self, new_embeddings):
613
- self.lm_head = new_embeddings
614
-
615
- def set_decoder(self, decoder):
616
- self.model = decoder
617
-
618
- def get_decoder(self):
619
- return self.model
620
-
621
- def forward(
622
- self,
623
- input_ids: torch.LongTensor = None,
624
- images: List[List[torch.Tensor]] = None,
625
- token_type_ids: Optional[torch.LongTensor] = None,
626
- attention_mask: Optional[torch.Tensor] = None,
627
- position_ids: Optional[torch.LongTensor] = None,
628
- past_key_values: Optional[List[torch.FloatTensor]] = None,
629
- inputs_embeds: Optional[torch.FloatTensor] = None,
630
- use_cache: Optional[bool] = None,
631
- output_attentions: Optional[bool] = None,
632
- output_hidden_states: Optional[bool] = None,
633
- return_dict: Optional[bool] = None,
634
- labels: Optional[torch.LongTensor] = None,
635
- ) -> Union[Tuple, CausalLMOutputWithPast]:
636
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
637
- output_hidden_states = (
638
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
639
- )
640
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
641
-
642
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
643
- outputs = self.model(
644
- input_ids=input_ids,
645
- images=images,
646
- token_type_ids=token_type_ids,
647
- attention_mask=attention_mask,
648
- position_ids=position_ids,
649
- past_key_values=past_key_values,
650
- inputs_embeds=inputs_embeds,
651
- use_cache=use_cache,
652
- output_attentions=output_attentions,
653
- output_hidden_states=output_hidden_states,
654
- return_dict=return_dict,
655
- )
656
-
657
- hidden_states = outputs[0]
658
- logits = self.lm_head(hidden_states)
659
- logits = logits.float()
660
-
661
- loss = None
662
- if labels is not None:
663
- # Shift so that tokens < n predict n
664
- shift_logits = logits[..., :-1, :].contiguous()
665
- shift_labels = labels[..., 1:].contiguous()
666
- # Flatten the tokens
667
- loss_fct = CrossEntropyLoss()
668
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
669
- shift_labels = shift_labels.view(-1)
670
- # Enable model parallelism
671
- shift_labels = shift_labels.to(shift_logits.device)
672
- loss = loss_fct(shift_logits, shift_labels)
673
-
674
- if not return_dict:
675
- output = (logits,) + outputs[1:]
676
- return (loss,) + output if loss is not None else output
677
-
678
- return CausalLMOutputWithPast(
679
- loss=loss,
680
- logits=logits,
681
- past_key_values=outputs.past_key_values,
682
- hidden_states=outputs.hidden_states,
683
- attentions=outputs.attentions,
684
- )
685
-
686
- def _prepare_attention_mask_for_generation(
687
- self,
688
- inputs: torch.Tensor,
689
- pad_token_id: Optional[int],
690
- eos_token_id: Optional[Union[int, List[int]]],
691
- ) -> torch.LongTensor:
692
- return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) # type: ignore
693
-
694
- def prepare_inputs_for_generation(
695
- self, input_ids, token_type_ids, images=None, past_key_values=None, attention_mask=None, inputs_embeds=None,
696
- **kwargs
697
- ):
698
- # build position_ids if needed
699
- position_ids = kwargs.get("position_ids", None)
700
- if position_ids is None:
701
- position_ids = build_position_ids(token_type_ids, attention_mask)
702
-
703
- if past_key_values:
704
- input_ids = input_ids[:, -1:]
705
- token_type_ids = token_type_ids[:, -1:]
706
- position_ids = position_ids[:, -1:]
707
-
708
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
709
- if inputs_embeds is not None and past_key_values is None:
710
- model_inputs = {"inputs_embeds": inputs_embeds}
711
- else:
712
- model_inputs = {"input_ids": input_ids}
713
-
714
- model_inputs.update(
715
- {
716
- "token_type_ids": token_type_ids,
717
- "images": images,
718
- "position_ids": position_ids,
719
- "past_key_values": past_key_values,
720
- "use_cache": kwargs.get("use_cache"),
721
- "attention_mask": attention_mask,
722
- }
723
- )
724
- return model_inputs
725
-
726
- def _update_model_kwargs_for_generation(
727
- self,
728
- outputs: "ModelOutput",
729
- model_kwargs: Dict[str, Any],
730
- is_encoder_decoder: bool = False,
731
- standardize_cache_format: bool = False,
732
- ) -> Dict[str, Any]:
733
- # update past_key_values
734
- if transformers.__version__ < "4.44.0":
735
- cache_name, cache = self._extract_past_from_model_output(
736
- outputs, standardize_cache_format=standardize_cache_format
737
- )
738
- elif transformers.__version__ < "4.49.0":
739
- cache_name, cache = self._extract_past_from_model_output(
740
- outputs
741
- )
742
- else:
743
- for possible_cache_name in ALL_CACHE_NAMES:
744
- if possible_cache_name in outputs:
745
- if possible_cache_name in ("past_buckets_states", "mems"):
746
- cache_name = "past_key_values"
747
- else:
748
- cache_name = possible_cache_name
749
- cache = getattr(outputs, possible_cache_name)
750
- break
751
- model_kwargs[cache_name] = cache
752
-
753
- if getattr(outputs, "state", None) is not None:
754
- model_kwargs["state"] = outputs.state
755
-
756
- # update token_type_ids with last value
757
- if "token_type_ids" in model_kwargs:
758
- token_type_ids = model_kwargs["token_type_ids"]
759
- new_token_type_ids = torch.ones(size=(token_type_ids.shape[0], 1), dtype=token_type_ids.dtype,
760
- device=token_type_ids.device) * LANGUAGE_TOKEN_TYPE
761
- model_kwargs["token_type_ids"] = torch.cat([token_type_ids, new_token_type_ids], dim=-1)
762
-
763
- if not is_encoder_decoder:
764
- # update attention mask
765
- if "attention_mask" in model_kwargs:
766
- attention_mask = model_kwargs["attention_mask"]
767
- model_kwargs["attention_mask"] = torch.cat(
768
- [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
769
- )
770
- else:
771
- # update decoder attention mask
772
- if "decoder_attention_mask" in model_kwargs:
773
- decoder_attention_mask = model_kwargs["decoder_attention_mask"]
774
- model_kwargs["decoder_attention_mask"] = torch.cat(
775
- [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
776
- dim=-1,
777
- )
778
-
779
- return model_kwargs
780
-
781
- def _reorder_cache(self, past_key_values, beam_idx):
782
- reordered_past = ()
783
- for layer_past in past_key_values:
784
- reordered_past += (
785
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
786
- )
787
- return reordered_past
788
-
789
- def build_conversation_input_ids(
790
- self,
791
- tokenizer: "PreTrainedTokenizer",
792
- *,
793
- query: str,
794
- history: Optional[List[Tuple[str, str]]] = None,
795
- images: Optional[List["PIL.Image"]] = None,
796
- template_version: Optional[Literal["base", "chat", "vqa"]] = None,
797
- answer: str = None,
798
- ):
799
- image_size: int = self.config.vision_config['image_size']
800
- template_version = template_version or self.config.template_version
801
- assert images is None or len(images) <= 1, f"not support multi images by now."
802
- history = history or []
803
- text = _history_to_prompt(template_version, history, query)
804
- input_ids = [tokenizer.bos_token_id]
805
- token_type_ids = [LANGUAGE_TOKEN_TYPE]
806
- add_time_indices = True if template_version == 'chat' else False
807
- if images is not None and len(images) == 1:
808
- # vision
809
- transform = transforms.Compose(
810
- [
811
- # UniformTemporalSubsample(num_frames),
812
- Lambda(lambda x: x / 255.0),
813
- NormalizeVideo(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),
814
- ShortSideScale(size=image_size),
815
- CenterCropVideo(image_size),
816
- # RandomHorizontalFlipVideo(p=0.5),
817
- ]
818
- )
819
- images = [transform(images[0]).transpose(0, 1)] # (T, C, H, W)
820
- num_eois = len(images[0])
821
- tokenizer.pad_token_id = 128002
822
- if not add_time_indices:
823
- vision_token_num = (64 + 2) * num_eois
824
- input_ids += [tokenizer.pad_token_id] * vision_token_num # add spetial token
825
- token_type_ids += [VISION_TOKEN_TYPE] * vision_token_num
826
- else:
827
- video_ids, video_type_ids = [], []
828
- sing_vision_token_num = (64 + 2)
829
- for _time_idx in range(num_eois):
830
- video_ids += [tokenizer.pad_token_id] * sing_vision_token_num
831
- video_type_ids += [VISION_TOKEN_TYPE] * sing_vision_token_num
832
- # add time indices
833
- time_indices = tokenizer.encode(str(_time_idx), add_special_tokens=False)
834
- video_ids += time_indices
835
- video_type_ids += [LANGUAGE_TOKEN_TYPE] * len(time_indices)
836
- # llama3 adapt for cogvlm
837
- input_ids += video_ids
838
- token_type_ids += video_type_ids
839
-
840
- text_ids = tokenizer.encode(text, add_special_tokens=False)
841
-
842
- if answer is not None:
843
- answer_ids = tokenizer.encode(answer, add_special_tokens=False)
844
- answer_ids += [tokenizer.eos_token_id]
845
- text_ids += answer_ids
846
-
847
- input_ids += text_ids
848
- token_type_ids += [LANGUAGE_TOKEN_TYPE] * len(text_ids)
849
- attention_mask = [1] * len(input_ids)
850
- if answer is not None:
851
- labels = [-100 for _ in range(len(input_ids) - len(answer_ids))] + answer_ids
852
- labels = torch.tensor(labels, dtype=torch.long)
853
- else:
854
- labels = None
855
-
856
- return {
857
- 'input_ids': torch.tensor(input_ids, dtype=torch.long),
858
- 'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
859
- 'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
860
- 'images': images,
861
- 'labels': labels,
862
- }
 
1
+ """largely copy from llama and adapt for cogvlm"""
2
+ import warnings
3
+ from typing import TYPE_CHECKING, Optional, Tuple, List, Union, Literal, Dict, Any
4
+
5
+ import math
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import CrossEntropyLoss
9
+ from torchvision import transforms
10
+ from einops import rearrange
11
+ import transformers
12
+ from transformers import PreTrainedModel, PreTrainedTokenizer
13
+ from transformers.utils.logging import get_logger
14
+ from transformers.activations import ACT2FN
15
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
16
+ from torchvision.transforms import Lambda
17
+ from torchvision.transforms._transforms_video import NormalizeVideo, CenterCropVideo
18
+ from pytorchvideo.transforms import ShortSideScale
19
+ from .configuration_cogvlm import CogVLMConfig
20
+ from .util import FastRotaryEmbedding
21
+ from .visual import EVA2CLIPModel
22
+
23
+ if TYPE_CHECKING:
24
+ from transformers.utils import ModelOutput
25
+
26
+ logger = get_logger(__name__)
27
+
28
+ LANGUAGE_TOKEN_TYPE = 0
29
+ VISION_TOKEN_TYPE = 1
30
+
31
+
32
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
33
+ def _make_causal_mask(
34
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
35
+ ):
36
+ """
37
+ Make causal mask used for bi-directional self-attention.
38
+ """
39
+ bsz, tgt_len = input_ids_shape
40
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
41
+ mask_cond = torch.arange(mask.size(-1), device=device)
42
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
43
+ mask = mask.to(dtype)
44
+
45
+ if past_key_values_length > 0:
46
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
47
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
48
+
49
+
50
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
51
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
52
+ """
53
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
54
+ """
55
+ bsz, src_len = mask.size()
56
+ tgt_len = tgt_len if tgt_len is not None else src_len
57
+
58
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
59
+
60
+ inverted_mask = 1.0 - expanded_mask
61
+
62
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
63
+
64
+
65
+ class RMSNorm(nn.Module):
66
+ def __init__(self, hidden_size, eps=1e-5):
67
+ super().__init__()
68
+ self.weight = nn.Parameter(torch.ones(hidden_size))
69
+ self.variance_epsilon = eps
70
+
71
+ def forward(self, hidden_states):
72
+ input_dtype = hidden_states.dtype
73
+ hidden_states = hidden_states.to(torch.float32)
74
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
75
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
76
+ return (self.weight * hidden_states).to(input_dtype)
77
+
78
+
79
+ class MLP(nn.Module):
80
+ def __init__(self, config):
81
+ super().__init__()
82
+ self.hidden_size = config.hidden_size
83
+ self.intermediate_size = config.intermediate_size
84
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
85
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
86
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
87
+ self.act_fn = ACT2FN[config.hidden_act]
88
+
89
+ def forward(self, x):
90
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
91
+ return down_proj
92
+
93
+
94
+ def get_expert_mask(token_type_ids: "torch.LongTensor(B, L)") -> "[torch.BoolTensor(B, L), torch.BoolTensor(B, L)]":
95
+ vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool)
96
+ vision_token_mask[:, :-1] = (token_type_ids[:, :-1] == VISION_TOKEN_TYPE) & (
97
+ token_type_ids[:, 1:] == VISION_TOKEN_TYPE)
98
+ language_token_mask = ~vision_token_mask
99
+ return vision_token_mask, language_token_mask
100
+
101
+
102
+ class VisionExpertMLP(nn.Module):
103
+ def __init__(self, config):
104
+ super().__init__()
105
+ self.language_mlp = MLP(config)
106
+ # self.vision_mlp = MLP(config)
107
+
108
+ def forward(self, hidden_states: "torch.Tensor(B, L, D)", token_type_ids: "torch.LongTensor(B, L)"):
109
+ # output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device)
110
+ # vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
111
+ # output[vision_token_mask] = self.vision_mlp(hidden_states[vision_token_mask])
112
+ # output[language_token_mask] = self.language_mlp(hidden_states[language_token_mask])
113
+
114
+ output = self.language_mlp(hidden_states)
115
+ return output
116
+
117
+
118
+ def attention_fn(
119
+ query_layer: "torch.tensor(B, H, L, HD)",
120
+ key_layer: "torch.tensor(B, H, L, HD)",
121
+ value_layer: "torch.tensor(B, H, L, HD)",
122
+ attention_mask: "torch.tensor(B, H, L, HD)",
123
+ *,
124
+ scaling_attention_score: bool = True,
125
+ attention_dropout: nn.Module = None
126
+ ):
127
+ attention_mask_bool = (attention_mask == 0)
128
+ is_low_triangle = (attention_mask_bool == torch.ones_like(attention_mask_bool, dtype=torch.float).tril()).all()
129
+ is_full = (attention_mask_bool > 0).all()
130
+ if not (int(torch.__version__.split('.')[0]) >= 2):
131
+ warnings.warn("It's recommended to use torch2.0 or higher.")
132
+ if int(torch.__version__.split('.')[0]) >= 2 and scaling_attention_score and (is_full or is_low_triangle):
133
+ dropout_p = 0. if attention_dropout is None or not attention_dropout.training else attention_dropout.p
134
+ return torch.nn.functional.scaled_dot_product_attention(
135
+ query_layer, key_layer, value_layer,
136
+ attn_mask=None,
137
+ dropout_p=dropout_p,
138
+ is_causal=not is_full
139
+ )
140
+ else:
141
+ if scaling_attention_score:
142
+ query_layer = query_layer / math.sqrt(query_layer.shape[-1])
143
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
144
+ attention_scores = attention_scores + attention_mask
145
+ attention_scores = nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).to(query_layer.dtype)
146
+ if attention_dropout is not None:
147
+ attention_scores = attention_dropout(attention_scores)
148
+ context_layer = torch.matmul(attention_scores, value_layer)
149
+ return context_layer
150
+
151
+
152
+ class VisionExpertAttention(nn.Module):
153
+ def __init__(self, config):
154
+ super().__init__()
155
+ self.config = config
156
+ self.hidden_size = config.hidden_size
157
+ self.num_attention_heads = config.num_attention_heads
158
+ self.num_multi_query_heads = config.num_multi_query_heads
159
+ self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads
160
+ self.stride = [self.num_attention_heads, self.num_multi_query_heads, self.num_multi_query_heads]
161
+ self.qkv_size = self.hidden_size + self.hidden_size_per_attention_head * self.num_multi_query_heads * 2
162
+ self.head_dim = self.hidden_size // self.num_attention_heads
163
+ self.max_position_embeddings = config.max_position_embeddings
164
+ self.rotary_emb = FastRotaryEmbedding(dim=self.head_dim, pos_idx_in_fp32=False, base=500000)
165
+ # self.vision_expert_query_key_value = nn.Linear(self.hidden_size, self.qkv_size, bias=True)
166
+ # self.vision_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
167
+ self.language_expert_query_key_value = nn.Linear(self.hidden_size, self.qkv_size, bias=False)
168
+ self.language_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
169
+
170
+ def _transpose_for_scores(self, tensor):
171
+ """Transpose a 3D tensor [B, L, H*HD] into a 4D tensor with size [B H L HD]."""
172
+ new_tensor_shape = tensor.size()[:-1] + \
173
+ (-1, # flexible for multi-query
174
+ self.hidden_size_per_attention_head)
175
+ tensor = tensor.view(*new_tensor_shape)
176
+ return tensor.permute(0, 2, 1, 3)
177
+
178
+ def forward(
179
+ self,
180
+ hidden_states: torch.Tensor,
181
+ token_type_ids: torch.LongTensor,
182
+ position_ids: torch.LongTensor,
183
+ attention_mask: Optional[torch.Tensor] = None,
184
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
185
+ output_attentions: bool = False,
186
+ use_cache: bool = False,
187
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
188
+ bsz, q_len, _ = hidden_states.size()
189
+ # vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
190
+
191
+ shape = list(hidden_states.shape)
192
+ shape[-1] = self.qkv_size
193
+ # mixed_raw_layer = torch.empty(shape, dtype=hidden_states.dtype, device=hidden_states.device)
194
+ # mixed_raw_layer[vision_token_mask] = self.vision_expert_query_key_value(hidden_states[vision_token_mask])
195
+ # mixed_raw_layer[language_token_mask] = self.language_expert_query_key_value(hidden_states[language_token_mask])
196
+ mixed_raw_layer = self.language_expert_query_key_value(hidden_states)
197
+
198
+ # query_states, key_states, value_states = torch.split(mixed_raw_layer, self.hidden_size, dim=-1)
199
+ factor = mixed_raw_layer.size()[-1] // sum(self.stride)
200
+ query_states, key_states, value_states = torch.split(mixed_raw_layer, [factor * x for x in self.stride], dim=-1)
201
+
202
+ query_states = self._transpose_for_scores(query_states) # B, H, L, HD
203
+ key_states = self._transpose_for_scores(key_states) # B, H, L, HD
204
+ value_states = self._transpose_for_scores(value_states) # B, H, L, HD
205
+
206
+ kv_seq_len = key_states.shape[-2]
207
+ if past_key_value is not None:
208
+ kv_seq_len += past_key_value[0].shape[-2]
209
+
210
+ query_states, key_states = self.rotary_emb(query_states, key_states, position_ids=position_ids,
211
+ max_seqlen=position_ids.max() + 1)
212
+
213
+ if past_key_value is not None:
214
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
215
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
216
+
217
+ past_key_value = (key_states, value_states) if use_cache else None
218
+
219
+ key_states = key_states.unsqueeze(2).expand(-1, -1, self.num_attention_heads // self.num_multi_query_heads, -1,
220
+ -1).contiguous().view(
221
+ bsz, self.num_attention_heads, *key_states.shape[2:])
222
+ value_states = value_states.unsqueeze(2).expand(-1, -1, self.num_attention_heads // self.num_multi_query_heads,
223
+ -1,
224
+ -1).contiguous().view(bsz, self.num_attention_heads,
225
+ *value_states.shape[2:])
226
+
227
+ context_layer = attention_fn(
228
+ query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
229
+ scaling_attention_score=True, attention_dropout=None)
230
+ if context_layer.size() != (bsz, self.num_attention_heads, q_len, self.head_dim):
231
+ raise ValueError(
232
+ f"`attn_output` should be of size {(bsz, self.num_attention_heads, q_len, self.head_dim)}, but is"
233
+ f" {context_layer.size()}"
234
+ )
235
+ context_layer = context_layer.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
236
+
237
+ # attn_output = torch.empty(context_layer.shape, dtype=hidden_states.dtype, device=hidden_states.device)
238
+ # attn_output[vision_token_mask] = self.vision_expert_dense(context_layer[vision_token_mask])
239
+ # attn_output[language_token_mask] = self.language_expert_dense(context_layer[language_token_mask])
240
+
241
+ attn_output = self.language_expert_dense(context_layer)
242
+
243
+ if output_attentions:
244
+ warnings.warn("output_attentions is not implemented.")
245
+
246
+ return attn_output, None, past_key_value
247
+
248
+
249
+ class CogVLMDecoderLayer(nn.Module):
250
+ def __init__(self, config):
251
+ super().__init__()
252
+ self.hidden_size = config.hidden_size
253
+ self.self_attn = VisionExpertAttention(config=config)
254
+ self.mlp = VisionExpertMLP(config)
255
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
256
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
257
+
258
+ def forward(
259
+ self,
260
+ hidden_states: torch.Tensor,
261
+ token_type_ids: torch.LongTensor,
262
+ position_ids: torch.LongTensor,
263
+ attention_mask: Optional[torch.Tensor] = None,
264
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
265
+ output_attentions: Optional[bool] = False,
266
+ use_cache: Optional[bool] = False,
267
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
268
+ residual = hidden_states
269
+
270
+ hidden_states = self.input_layernorm(hidden_states)
271
+
272
+ # Self Attention
273
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
274
+ hidden_states=hidden_states,
275
+ token_type_ids=token_type_ids,
276
+ position_ids=position_ids,
277
+ attention_mask=attention_mask,
278
+ past_key_value=past_key_value,
279
+ output_attentions=output_attentions,
280
+ use_cache=use_cache,
281
+ )
282
+ hidden_states = residual + hidden_states
283
+
284
+ # Fully Connected
285
+ residual = hidden_states
286
+ hidden_states = self.post_attention_layernorm(hidden_states)
287
+ hidden_states = self.mlp(hidden_states, token_type_ids=token_type_ids)
288
+ hidden_states = residual + hidden_states
289
+
290
+ outputs = (hidden_states,)
291
+
292
+ if output_attentions:
293
+ outputs += (self_attn_weights,)
294
+
295
+ if use_cache:
296
+ outputs += (present_key_value,)
297
+
298
+ return outputs # type: ignore
299
+
300
+
301
+ class CogVLMPreTrainedModel(PreTrainedModel):
302
+ config_class = CogVLMConfig
303
+ base_model_prefix = "model"
304
+ supports_gradient_checkpointing = False
305
+ _no_split_modules = ["CogVLMDecoderLayer"]
306
+ _skip_keys_device_placement = "past_key_values"
307
+
308
+ def _init_weights(self, module):
309
+ std = self.config.initializer_range
310
+ if isinstance(module, nn.Linear):
311
+ module.weight.data.normal_(mean=0.0, std=std)
312
+ if module.bias is not None:
313
+ module.bias.data.zero_()
314
+ elif isinstance(module, nn.Embedding):
315
+ module.weight.data.normal_(mean=0.0, std=std)
316
+ if module.padding_idx is not None:
317
+ module.weight.data[module.padding_idx].zero_()
318
+
319
+
320
+ def is_empty(images_list: Optional[List[List[torch.Tensor]]]):
321
+ if images_list is None or len(images_list) == 0:
322
+ return True
323
+ for image_list in images_list:
324
+ if len(image_list):
325
+ return False
326
+ return True
327
+
328
+
329
+ def build_position_ids(x: "torch.BoolTensor(B, L)",
330
+ attention_mask: Optional["torch.BoolTensor(B, L)"] = None) -> "torch.LongTensor(B, L)":
331
+ if attention_mask is not None:
332
+ tmp = x.clone()
333
+ tmp[~(attention_mask.bool())] = -1
334
+ else:
335
+ tmp = x.clone()
336
+ # image boi eoi token as LANGUAGE_TOKEN_TYPE
337
+ is_boi_eoi = torch.zeros_like(x, dtype=torch.bool)
338
+ is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
339
+ is_boi_eoi[:, 0] |= (tmp[:, 0] == VISION_TOKEN_TYPE)
340
+ is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE)
341
+ is_boi_eoi[:, -1] |= (tmp[:, -1] == VISION_TOKEN_TYPE)
342
+ tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
343
+ # final position ids
344
+ y = torch.zeros_like(x, dtype=torch.long)
345
+ y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | (
346
+ (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE))
347
+ y = y.cumsum(dim=-1)
348
+ return y
349
+
350
+
351
+ class CogVLMVideoModel(CogVLMPreTrainedModel):
352
+ def __init__(self, config):
353
+ super().__init__(config)
354
+ self.padding_idx = 128002
355
+ self.vocab_size = config.vocab_size
356
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
357
+ self.layers = nn.ModuleList([CogVLMDecoderLayer(config) for _ in range(config.num_hidden_layers)])
358
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
359
+
360
+ self.vision = EVA2CLIPModel(config)
361
+
362
+ self.gradient_checkpointing = False
363
+ # Initialize weights and apply final processing
364
+ self.post_init()
365
+
366
+ def encode_images(self, images: List[List[torch.Tensor]], ) -> torch.Tensor:
367
+ images_list, images = images, []
368
+
369
+ images = []
370
+ for image_list in images_list:
371
+ for image in image_list:
372
+ images.append(image)
373
+
374
+ # images = torch.stack(images) # video images is already stacked
375
+ images_features = self.vision(images[0])
376
+ return images_features
377
+
378
+ def forward(
379
+ self,
380
+ input_ids: torch.LongTensor = None,
381
+ images: List[List[torch.Tensor]] = None,
382
+ token_type_ids: Optional[torch.LongTensor] = None,
383
+ attention_mask: Optional[torch.Tensor] = None,
384
+ position_ids: Optional[torch.LongTensor] = None,
385
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
386
+ inputs_embeds: Optional[torch.FloatTensor] = None,
387
+ use_cache: Optional[bool] = None,
388
+ output_attentions: Optional[bool] = None,
389
+ output_hidden_states: Optional[bool] = None,
390
+ return_dict: Optional[bool] = None,
391
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
392
+ """take care of image_encode, token_type_ids, position_ids and (attention_mask = None is fine)"""
393
+
394
+ if past_key_values is not None:
395
+ pass # generate mode with past_key_values. the image features are already mapped
396
+ else:
397
+ # not allow for inputs_embeds, because we want to process image feature
398
+ assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
399
+ if not is_empty(images): # multi-modality
400
+ assert token_type_ids is not None, f"multi-modality requires `token_type_ids`!"
401
+ assert len(input_ids) == len(images), f"{len(input_ids)} {len(images)}"
402
+ inputs_embeds = self.embed_tokens(input_ids)
403
+ images_features = self.encode_images(images)
404
+ images_features = rearrange(images_features, 'b n d -> (b n) d')
405
+ images_features = images_features.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
406
+ inputs_embeds = inputs_embeds.index_put([token_type_ids == VISION_TOKEN_TYPE], images_features)
407
+ else: # single-modality
408
+ if token_type_ids is None:
409
+ token_type_ids = torch.ones_like(input_ids, dtype=torch.long,
410
+ device=input_ids.device) * LANGUAGE_TOKEN_TYPE
411
+ assert not (token_type_ids == VISION_TOKEN_TYPE).any(), f"{(token_type_ids == VISION_TOKEN_TYPE).sum()}"
412
+ inputs_embeds = self.embed_tokens(input_ids)
413
+
414
+ if position_ids is None:
415
+ position_ids = build_position_ids(token_type_ids, attention_mask)
416
+ input_ids = None
417
+ return self.llm_forward(
418
+ input_ids=input_ids,
419
+ token_type_ids=token_type_ids,
420
+ attention_mask=attention_mask,
421
+ position_ids=position_ids,
422
+ past_key_values=past_key_values,
423
+ inputs_embeds=inputs_embeds,
424
+ use_cache=use_cache,
425
+ output_attentions=output_attentions,
426
+ output_hidden_states=output_hidden_states,
427
+ return_dict=return_dict,
428
+ )
429
+
430
+ def llm_forward(
431
+ self,
432
+ input_ids: torch.LongTensor = None,
433
+ token_type_ids: torch.LongTensor = None,
434
+ attention_mask: Optional[torch.Tensor] = None,
435
+ position_ids: Optional[torch.LongTensor] = None,
436
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
437
+ inputs_embeds: Optional[torch.FloatTensor] = None,
438
+ use_cache: Optional[bool] = None,
439
+ output_attentions: Optional[bool] = None,
440
+ output_hidden_states: Optional[bool] = None,
441
+ return_dict: Optional[bool] = None,
442
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
443
+ """largely copy from llama forward and adapt for cogvlm with `token_type_ids`"""
444
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
445
+ output_hidden_states = (
446
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
447
+ )
448
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
449
+
450
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
451
+
452
+ # retrieve input_ids and inputs_embeds
453
+ if input_ids is not None and inputs_embeds is not None:
454
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
455
+ elif input_ids is not None:
456
+ batch_size, seq_length = input_ids.shape
457
+ elif inputs_embeds is not None:
458
+ batch_size, seq_length, _ = inputs_embeds.shape
459
+ else:
460
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
461
+
462
+ seq_length_with_past = seq_length
463
+ past_key_values_length = 0
464
+
465
+ if past_key_values is not None:
466
+ past_key_values_length = past_key_values[0][0].shape[2]
467
+ seq_length_with_past = seq_length_with_past + past_key_values_length
468
+
469
+ if position_ids is None:
470
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
471
+ position_ids = torch.arange(
472
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
473
+ )
474
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
475
+ else:
476
+ position_ids = position_ids.view(-1, seq_length).long()
477
+
478
+ if inputs_embeds is None:
479
+ inputs_embeds = self.embed_tokens(input_ids)
480
+ # embed positions
481
+ if attention_mask is None:
482
+ attention_mask = torch.ones(
483
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
484
+ )
485
+ attention_mask = self._prepare_decoder_attention_mask(
486
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
487
+ )
488
+
489
+ hidden_states = inputs_embeds
490
+
491
+ # decoder layers
492
+ all_hidden_states = () if output_hidden_states else None
493
+ all_self_attns = () if output_attentions else None
494
+ next_decoder_cache = () if use_cache else None
495
+
496
+ for idx, decoder_layer in enumerate(self.layers):
497
+ if output_hidden_states:
498
+ all_hidden_states += (hidden_states,)
499
+
500
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
501
+ layer_outputs = decoder_layer(
502
+ hidden_states,
503
+ token_type_ids=token_type_ids,
504
+ attention_mask=attention_mask,
505
+ position_ids=position_ids,
506
+ past_key_value=past_key_value,
507
+ output_attentions=output_attentions,
508
+ use_cache=use_cache,
509
+ )
510
+ hidden_states = layer_outputs[0]
511
+
512
+ if use_cache:
513
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
514
+
515
+ if output_attentions:
516
+ all_self_attns += (layer_outputs[1],)
517
+
518
+ hidden_states = self.norm(hidden_states)
519
+
520
+ # add hidden states from the last decoder layer
521
+ if output_hidden_states:
522
+ all_hidden_states += (hidden_states,)
523
+
524
+ next_cache = next_decoder_cache if use_cache else None
525
+ if not return_dict:
526
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
527
+ return BaseModelOutputWithPast(
528
+ last_hidden_state=hidden_states,
529
+ past_key_values=next_cache,
530
+ hidden_states=all_hidden_states,
531
+ attentions=all_self_attns,
532
+ )
533
+
534
+ def get_input_embeddings(self):
535
+ return self.embed_tokens
536
+
537
+ def set_input_embeddings(self, value):
538
+ self.embed_tokens = value
539
+
540
+ # noinspection PyMethodMayBeStatic
541
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
542
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
543
+ # create causal mask
544
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
545
+ combined_attention_mask = None
546
+ if input_shape[-1] > 1:
547
+ combined_attention_mask = _make_causal_mask(
548
+ input_shape,
549
+ inputs_embeds.dtype,
550
+ device=inputs_embeds.device,
551
+ past_key_values_length=past_key_values_length,
552
+ )
553
+
554
+ if attention_mask is not None:
555
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
556
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
557
+ inputs_embeds.device
558
+ )
559
+ combined_attention_mask = (
560
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
561
+ )
562
+
563
+ return combined_attention_mask
564
+
565
+
566
+ def _history_to_prompt(signal_type, history, query):
567
+ if signal_type == 'base':
568
+ return query
569
+ elif signal_type == 'vqa':
570
+ answer_format = 'Short answer:'
571
+ elif signal_type == 'chat':
572
+ answer_format = 'Answer:'
573
+ else:
574
+ assert False, f"Unknown signal type {signal_type}"
575
+
576
+ prompt = ''
577
+ for i, (old_query, response) in enumerate(history):
578
+ prompt += 'Question: ' + old_query + " {} ".format(answer_format) + response + "\n"
579
+ prompt += 'Question: {} {}'.format(query, answer_format)
580
+ return prompt
581
+
582
+
583
+ class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
584
+ _auto_class = "AutoModelForCausalLM"
585
+
586
+ def __init__(self, config):
587
+ super().__init__(config)
588
+ self.model = CogVLMVideoModel(config)
589
+ self.vocab_size = config.vocab_size
590
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
591
+ self.video_downsample = 1 # TODO: change this to config
592
+
593
+ # Initialize weights and apply final processing
594
+ self.post_init()
595
+
596
+ def get_input_embeddings(self):
597
+ return self.model.embed_tokens
598
+
599
+ def set_input_embeddings(self, value):
600
+ self.model.embed_tokens = value
601
+
602
+ def get_output_embeddings(self):
603
+ return self.lm_head
604
+
605
+ def set_output_embeddings(self, new_embeddings):
606
+ self.lm_head = new_embeddings
607
+
608
+ def set_decoder(self, decoder):
609
+ self.model = decoder
610
+
611
+ def get_decoder(self):
612
+ return self.model
613
+
614
+ def forward(
615
+ self,
616
+ input_ids: torch.LongTensor = None,
617
+ images: List[List[torch.Tensor]] = None,
618
+ token_type_ids: Optional[torch.LongTensor] = None,
619
+ attention_mask: Optional[torch.Tensor] = None,
620
+ position_ids: Optional[torch.LongTensor] = None,
621
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
622
+ inputs_embeds: Optional[torch.FloatTensor] = None,
623
+ use_cache: Optional[bool] = None,
624
+ output_attentions: Optional[bool] = None,
625
+ output_hidden_states: Optional[bool] = None,
626
+ return_dict: Optional[bool] = None,
627
+ labels: Optional[torch.LongTensor] = None,
628
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
629
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
630
+ output_hidden_states = (
631
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
632
+ )
633
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
634
+
635
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
636
+ outputs = self.model(
637
+ input_ids=input_ids,
638
+ images=images,
639
+ token_type_ids=token_type_ids,
640
+ attention_mask=attention_mask,
641
+ position_ids=position_ids,
642
+ past_key_values=past_key_values,
643
+ inputs_embeds=inputs_embeds,
644
+ use_cache=use_cache,
645
+ output_attentions=output_attentions,
646
+ output_hidden_states=output_hidden_states,
647
+ return_dict=return_dict,
648
+ )
649
+
650
+ hidden_states = outputs[0]
651
+ logits = self.lm_head(hidden_states)
652
+ logits = logits.float()
653
+
654
+ loss = None
655
+ if labels is not None:
656
+ # Shift so that tokens < n predict n
657
+ shift_logits = logits[..., :-1, :].contiguous()
658
+ shift_labels = labels[..., 1:].contiguous()
659
+ # Flatten the tokens
660
+ loss_fct = CrossEntropyLoss()
661
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
662
+ shift_labels = shift_labels.view(-1)
663
+ # Enable model parallelism
664
+ shift_labels = shift_labels.to(shift_logits.device)
665
+ loss = loss_fct(shift_logits, shift_labels)
666
+
667
+ if not return_dict:
668
+ output = (logits,) + outputs[1:]
669
+ return (loss,) + output if loss is not None else output
670
+
671
+ return CausalLMOutputWithPast(
672
+ loss=loss,
673
+ logits=logits,
674
+ past_key_values=outputs.past_key_values,
675
+ hidden_states=outputs.hidden_states,
676
+ attentions=outputs.attentions,
677
+ )
678
+
679
+ def _prepare_attention_mask_for_generation(
680
+ self,
681
+ inputs: torch.Tensor,
682
+ pad_token_id: Optional[int],
683
+ eos_token_id: Optional[Union[int, List[int]]],
684
+ ) -> torch.LongTensor:
685
+ return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) # type: ignore
686
+
687
+ def prepare_inputs_for_generation(
688
+ self, input_ids, token_type_ids, images=None, past_key_values=None, attention_mask=None, inputs_embeds=None,
689
+ **kwargs
690
+ ):
691
+ # build position_ids if needed
692
+ position_ids = kwargs.get("position_ids", None)
693
+ if position_ids is None:
694
+ position_ids = build_position_ids(token_type_ids, attention_mask)
695
+
696
+ if past_key_values:
697
+ input_ids = input_ids[:, -1:]
698
+ token_type_ids = token_type_ids[:, -1:]
699
+ position_ids = position_ids[:, -1:]
700
+
701
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
702
+ if inputs_embeds is not None and past_key_values is None:
703
+ model_inputs = {"inputs_embeds": inputs_embeds}
704
+ else:
705
+ model_inputs = {"input_ids": input_ids}
706
+
707
+ model_inputs.update(
708
+ {
709
+ "token_type_ids": token_type_ids,
710
+ "images": images,
711
+ "position_ids": position_ids,
712
+ "past_key_values": past_key_values,
713
+ "use_cache": kwargs.get("use_cache"),
714
+ "attention_mask": attention_mask,
715
+ }
716
+ )
717
+ return model_inputs
718
+
719
+ def _update_model_kwargs_for_generation(
720
+ self,
721
+ outputs: "ModelOutput",
722
+ model_kwargs: Dict[str, Any],
723
+ is_encoder_decoder: bool = False,
724
+ standardize_cache_format: bool = False,
725
+ ) -> Dict[str, Any]:
726
+ # update past_key_values
727
+ if transformers.__version__ >= "4.44.0":
728
+ cache_name, cache = self._extract_past_from_model_output(
729
+ outputs
730
+ )
731
+ else:
732
+ cache_name, cache = self._extract_past_from_model_output(
733
+ outputs, standardize_cache_format=standardize_cache_format
734
+ )
735
+ model_kwargs[cache_name] = cache
736
+
737
+ if getattr(outputs, "state", None) is not None:
738
+ model_kwargs["state"] = outputs.state
739
+
740
+ # update token_type_ids with last value
741
+ if "token_type_ids" in model_kwargs:
742
+ token_type_ids = model_kwargs["token_type_ids"]
743
+ new_token_type_ids = torch.ones(size=(token_type_ids.shape[0], 1), dtype=token_type_ids.dtype,
744
+ device=token_type_ids.device) * LANGUAGE_TOKEN_TYPE
745
+ model_kwargs["token_type_ids"] = torch.cat([token_type_ids, new_token_type_ids], dim=-1)
746
+
747
+ if not is_encoder_decoder:
748
+ # update attention mask
749
+ if "attention_mask" in model_kwargs:
750
+ attention_mask = model_kwargs["attention_mask"]
751
+ model_kwargs["attention_mask"] = torch.cat(
752
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
753
+ )
754
+ else:
755
+ # update decoder attention mask
756
+ if "decoder_attention_mask" in model_kwargs:
757
+ decoder_attention_mask = model_kwargs["decoder_attention_mask"]
758
+ model_kwargs["decoder_attention_mask"] = torch.cat(
759
+ [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
760
+ dim=-1,
761
+ )
762
+
763
+ return model_kwargs
764
+
765
+ def _reorder_cache(self, past_key_values, beam_idx):
766
+ reordered_past = ()
767
+ for layer_past in past_key_values:
768
+ reordered_past += (
769
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
770
+ )
771
+ return reordered_past
772
+
773
+ def build_conversation_input_ids(
774
+ self,
775
+ tokenizer: "PreTrainedTokenizer",
776
+ *,
777
+ query: str,
778
+ history: Optional[List[Tuple[str, str]]] = None,
779
+ images: Optional[List["PIL.Image"]] = None,
780
+ template_version: Optional[Literal["base", "chat", "vqa"]] = None,
781
+ answer: str = None,
782
+ ):
783
+ image_size: int = self.config.vision_config['image_size']
784
+ template_version = template_version or self.config.template_version
785
+ assert images is None or len(images) <= 1, f"not support multi images by now."
786
+ history = history or []
787
+ text = _history_to_prompt(template_version, history, query)
788
+ input_ids = [tokenizer.bos_token_id]
789
+ token_type_ids = [LANGUAGE_TOKEN_TYPE]
790
+ add_time_indices = True if template_version == 'chat' else False
791
+ if images is not None and len(images) == 1:
792
+ # vision
793
+ transform = transforms.Compose(
794
+ [
795
+ # UniformTemporalSubsample(num_frames),
796
+ Lambda(lambda x: x / 255.0),
797
+ NormalizeVideo(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),
798
+ ShortSideScale(size=image_size),
799
+ CenterCropVideo(image_size),
800
+ # RandomHorizontalFlipVideo(p=0.5),
801
+ ]
802
+ )
803
+ images = [transform(images[0]).transpose(0, 1)] # (T, C, H, W)
804
+ num_eois = len(images[0])
805
+ tokenizer.pad_token_id = 128002
806
+ if not add_time_indices:
807
+ vision_token_num = (64 + 2) * num_eois
808
+ input_ids += [tokenizer.pad_token_id] * vision_token_num # add spetial token
809
+ token_type_ids += [VISION_TOKEN_TYPE] * vision_token_num
810
+ else:
811
+ video_ids, video_type_ids = [], []
812
+ sing_vision_token_num = (64 + 2)
813
+ for _time_idx in range(num_eois):
814
+ video_ids += [tokenizer.pad_token_id] * sing_vision_token_num
815
+ video_type_ids += [VISION_TOKEN_TYPE] * sing_vision_token_num
816
+ # add time indices
817
+ time_indices = tokenizer.encode(str(_time_idx), add_special_tokens=False)
818
+ video_ids += time_indices
819
+ video_type_ids += [LANGUAGE_TOKEN_TYPE] * len(time_indices)
820
+ # llama3 adapt for cogvlm
821
+ input_ids += video_ids
822
+ token_type_ids += video_type_ids
823
+
824
+ text_ids = tokenizer.encode(text, add_special_tokens=False)
825
+
826
+ if answer is not None:
827
+ answer_ids = tokenizer.encode(answer, add_special_tokens=False)
828
+ answer_ids += [tokenizer.eos_token_id]
829
+ text_ids += answer_ids
830
+
831
+ input_ids += text_ids
832
+ token_type_ids += [LANGUAGE_TOKEN_TYPE] * len(text_ids)
833
+ attention_mask = [1] * len(input_ids)
834
+ if answer is not None:
835
+ labels = [-100 for _ in range(len(input_ids) - len(answer_ids))] + answer_ids
836
+ labels = torch.tensor(labels, dtype=torch.long)
837
+ else:
838
+ labels = None
839
+
840
+ return {
841
+ 'input_ids': torch.tensor(input_ids, dtype=torch.long),
842
+ 'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
843
+ 'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
844
+ 'images': images,
845
+ 'labels': labels,
846
+ }