harelcain committed on
Commit 1b45a45 · verified · 1 Parent(s): f60e47a

Update modeling_cogvlm.py to support newer diffusers versions

Similarly to https://huggingface.co/allenai/Molmo-7B-D-0924/discussions/43/files and https://huggingface.co/THUDM/chatglm3-6b/commit/67d005d386a01d4825649743f41e90f83edd6094, this updates the model code to handle a breaking change in newer versions of diffusers.
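
For context, _update_model_kwargs_for_generation in this file calls the inherited _extract_past_from_model_output(outputs, standardize_cache_format=...) helper, whose signature has changed across recent library releases. Below is a minimal, hypothetical sketch (not the contents of this commit) of a version-tolerant wrapper; it assumes the breaking change is the removal of the standardize_cache_format argument and, in the newest releases, of the helper itself, and the actual change here may differ.

# Hypothetical compatibility shim, not the actual commit contents: resolve the
# generation cache across library versions that changed or removed
# `_extract_past_from_model_output`.
import inspect

def extract_past_compat(model, outputs, standardize_cache_format=False):
    """Return (cache_name, cache), assuming versions where the helper already
    returns that pair (as this file expects) or has been removed entirely."""
    extract = getattr(model, "_extract_past_from_model_output", None)
    if extract is None:
        # Assumed newest case: the private helper is gone; read the cache directly.
        return "past_key_values", getattr(outputs, "past_key_values", None)
    if "standardize_cache_format" in inspect.signature(extract).parameters:
        # Older signature that still accepts the flag.
        return extract(outputs, standardize_cache_format=standardize_cache_format)
    # Newer signature with the flag removed.
    return extract(outputs)

Under those assumptions, the existing cache_name, cache = self._extract_past_from_model_output(outputs, standardize_cache_format=standardize_cache_format) line in _update_model_kwargs_for_generation would become cache_name, cache = extract_past_compat(self, outputs, standardize_cache_format).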

Files changed (1)
  1. modeling_cogvlm.py +840 -840
modeling_cogvlm.py CHANGED
@@ -1,840 +1,840 @@
1
- """largely copy from llama and adapt for cogvlm"""
2
- import warnings
3
- from typing import TYPE_CHECKING, Optional, Tuple, List, Union, Literal, Dict, Any
4
-
5
- import math
6
- import torch
7
- from torch import nn
8
- from torch.nn import CrossEntropyLoss
9
- from torchvision import transforms
10
- from einops import rearrange
11
- from transformers import PreTrainedModel, PreTrainedTokenizer
12
- from transformers.utils.logging import get_logger
13
- from transformers.activations import ACT2FN
14
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
15
- from torchvision.transforms import Lambda
16
- from torchvision.transforms._transforms_video import NormalizeVideo, CenterCropVideo
17
- from pytorchvideo.transforms import ShortSideScale
18
- from .configuration_cogvlm import CogVLMConfig
19
- from .util import FastRotaryEmbedding
20
- from .visual import EVA2CLIPModel
21
-
22
- if TYPE_CHECKING:
23
- from transformers.utils import ModelOutput
24
-
25
- logger = get_logger(__name__)
26
-
27
- LANGUAGE_TOKEN_TYPE = 0
28
- VISION_TOKEN_TYPE = 1
29
-
30
-
31
- # Copied from transformers.models.bart.modeling_bart._make_causal_mask
32
- def _make_causal_mask(
33
- input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
34
- ):
35
- """
36
- Make causal mask used for bi-directional self-attention.
37
- """
38
- bsz, tgt_len = input_ids_shape
39
- mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
40
- mask_cond = torch.arange(mask.size(-1), device=device)
41
- mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
42
- mask = mask.to(dtype)
43
-
44
- if past_key_values_length > 0:
45
- mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
46
- return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
47
-
48
-
49
- # Copied from transformers.models.bart.modeling_bart._expand_mask
50
- def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
51
- """
52
- Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
53
- """
54
- bsz, src_len = mask.size()
55
- tgt_len = tgt_len if tgt_len is not None else src_len
56
-
57
- expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
58
-
59
- inverted_mask = 1.0 - expanded_mask
60
-
61
- return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
62
-
63
-
64
- class RMSNorm(nn.Module):
65
- def __init__(self, hidden_size, eps=1e-5):
66
- super().__init__()
67
- self.weight = nn.Parameter(torch.ones(hidden_size))
68
- self.variance_epsilon = eps
69
-
70
- def forward(self, hidden_states):
71
- input_dtype = hidden_states.dtype
72
- hidden_states = hidden_states.to(torch.float32)
73
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
74
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
75
- return (self.weight * hidden_states).to(input_dtype)
76
-
77
-
78
- class MLP(nn.Module):
79
- def __init__(self, config):
80
- super().__init__()
81
- self.hidden_size = config.hidden_size
82
- self.intermediate_size = config.intermediate_size
83
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
84
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
85
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
86
- self.act_fn = ACT2FN[config.hidden_act]
87
-
88
- def forward(self, x):
89
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
90
- return down_proj
91
-
92
-
93
- def get_expert_mask(token_type_ids: "torch.LongTensor(B, L)") -> "[torch.BoolTensor(B, L), torch.BoolTensor(B, L)]":
94
- vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool)
95
- vision_token_mask[:, :-1] = (token_type_ids[:, :-1] == VISION_TOKEN_TYPE) & (
96
- token_type_ids[:, 1:] == VISION_TOKEN_TYPE)
97
- language_token_mask = ~vision_token_mask
98
- return vision_token_mask, language_token_mask
99
-
100
-
101
- class VisionExpertMLP(nn.Module):
102
- def __init__(self, config):
103
- super().__init__()
104
- self.language_mlp = MLP(config)
105
- # self.vision_mlp = MLP(config)
106
-
107
- def forward(self, hidden_states: "torch.Tensor(B, L, D)", token_type_ids: "torch.LongTensor(B, L)"):
108
- # output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device)
109
- # vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
110
- # output[vision_token_mask] = self.vision_mlp(hidden_states[vision_token_mask])
111
- # output[language_token_mask] = self.language_mlp(hidden_states[language_token_mask])
112
-
113
- output = self.language_mlp(hidden_states)
114
- return output
115
-
116
-
117
- def attention_fn(
118
- query_layer: "torch.tensor(B, H, L, HD)",
119
- key_layer: "torch.tensor(B, H, L, HD)",
120
- value_layer: "torch.tensor(B, H, L, HD)",
121
- attention_mask: "torch.tensor(B, H, L, HD)",
122
- *,
123
- scaling_attention_score: bool = True,
124
- attention_dropout: nn.Module = None
125
- ):
126
- attention_mask_bool = (attention_mask == 0)
127
- is_low_triangle = (attention_mask_bool == torch.ones_like(attention_mask_bool, dtype=torch.float).tril()).all()
128
- is_full = (attention_mask_bool > 0).all()
129
- if not (int(torch.__version__.split('.')[0]) >= 2):
130
- warnings.warn("It's recommended to use torch2.0 or higher.")
131
- if int(torch.__version__.split('.')[0]) >= 2 and scaling_attention_score and (is_full or is_low_triangle):
132
- dropout_p = 0. if attention_dropout is None or not attention_dropout.training else attention_dropout.p
133
- return torch.nn.functional.scaled_dot_product_attention(
134
- query_layer, key_layer, value_layer,
135
- attn_mask=None,
136
- dropout_p=dropout_p,
137
- is_causal=not is_full
138
- )
139
- else:
140
- if scaling_attention_score:
141
- query_layer = query_layer / math.sqrt(query_layer.shape[-1])
142
- attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
143
- attention_scores = attention_scores + attention_mask
144
- attention_scores = nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).to(query_layer.dtype)
145
- if attention_dropout is not None:
146
- attention_scores = attention_dropout(attention_scores)
147
- context_layer = torch.matmul(attention_scores, value_layer)
148
- return context_layer
149
-
150
-
151
- class VisionExpertAttention(nn.Module):
152
- def __init__(self, config):
153
- super().__init__()
154
- self.config = config
155
- self.hidden_size = config.hidden_size
156
- self.num_attention_heads = config.num_attention_heads
157
- self.num_multi_query_heads = config.num_multi_query_heads
158
- self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads
159
- self.stride = [self.num_attention_heads, self.num_multi_query_heads, self.num_multi_query_heads]
160
- self.qkv_size = self.hidden_size + self.hidden_size_per_attention_head * self.num_multi_query_heads * 2
161
- self.head_dim = self.hidden_size // self.num_attention_heads
162
- self.max_position_embeddings = config.max_position_embeddings
163
- self.rotary_emb = FastRotaryEmbedding(dim=self.head_dim, pos_idx_in_fp32=False, base=500000)
164
- # self.vision_expert_query_key_value = nn.Linear(self.hidden_size, self.qkv_size, bias=True)
165
- # self.vision_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
166
- self.language_expert_query_key_value = nn.Linear(self.hidden_size, self.qkv_size, bias=False)
167
- self.language_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
168
-
169
- def _transpose_for_scores(self, tensor):
170
- """Transpose a 3D tensor [B, L, H*HD] into a 4D tensor with size [B H L HD]."""
171
- new_tensor_shape = tensor.size()[:-1] + \
172
- (-1, # flexible for multi-query
173
- self.hidden_size_per_attention_head)
174
- tensor = tensor.view(*new_tensor_shape)
175
- return tensor.permute(0, 2, 1, 3)
176
-
177
- def forward(
178
- self,
179
- hidden_states: torch.Tensor,
180
- token_type_ids: torch.LongTensor,
181
- position_ids: torch.LongTensor,
182
- attention_mask: Optional[torch.Tensor] = None,
183
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
184
- output_attentions: bool = False,
185
- use_cache: bool = False,
186
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
187
- bsz, q_len, _ = hidden_states.size()
188
- # vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
189
-
190
- shape = list(hidden_states.shape)
191
- shape[-1] = self.qkv_size
192
- # mixed_raw_layer = torch.empty(shape, dtype=hidden_states.dtype, device=hidden_states.device)
193
- # mixed_raw_layer[vision_token_mask] = self.vision_expert_query_key_value(hidden_states[vision_token_mask])
194
- # mixed_raw_layer[language_token_mask] = self.language_expert_query_key_value(hidden_states[language_token_mask])
195
- mixed_raw_layer = self.language_expert_query_key_value(hidden_states)
196
-
197
- # query_states, key_states, value_states = torch.split(mixed_raw_layer, self.hidden_size, dim=-1)
198
- factor = mixed_raw_layer.size()[-1] // sum(self.stride)
199
- query_states, key_states, value_states = torch.split(mixed_raw_layer, [factor * x for x in self.stride], dim=-1)
200
-
201
- query_states = self._transpose_for_scores(query_states) # B, H, L, HD
202
- key_states = self._transpose_for_scores(key_states) # B, H, L, HD
203
- value_states = self._transpose_for_scores(value_states) # B, H, L, HD
204
-
205
- kv_seq_len = key_states.shape[-2]
206
- if past_key_value is not None:
207
- kv_seq_len += past_key_value[0].shape[-2]
208
-
209
- query_states, key_states = self.rotary_emb(query_states, key_states, position_ids=position_ids,
210
- max_seqlen=position_ids.max() + 1)
211
-
212
- if past_key_value is not None:
213
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
214
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
215
-
216
- past_key_value = (key_states, value_states) if use_cache else None
217
-
218
- key_states = key_states.unsqueeze(2).expand(-1, -1, self.num_attention_heads // self.num_multi_query_heads, -1,
219
- -1).contiguous().view(
220
- bsz, self.num_attention_heads, *key_states.shape[2:])
221
- value_states = value_states.unsqueeze(2).expand(-1, -1, self.num_attention_heads // self.num_multi_query_heads,
222
- -1,
223
- -1).contiguous().view(bsz, self.num_attention_heads,
224
- *value_states.shape[2:])
225
-
226
- context_layer = attention_fn(
227
- query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
228
- scaling_attention_score=True, attention_dropout=None)
229
- if context_layer.size() != (bsz, self.num_attention_heads, q_len, self.head_dim):
230
- raise ValueError(
231
- f"`attn_output` should be of size {(bsz, self.num_attention_heads, q_len, self.head_dim)}, but is"
232
- f" {context_layer.size()}"
233
- )
234
- context_layer = context_layer.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
235
-
236
- # attn_output = torch.empty(context_layer.shape, dtype=hidden_states.dtype, device=hidden_states.device)
237
- # attn_output[vision_token_mask] = self.vision_expert_dense(context_layer[vision_token_mask])
238
- # attn_output[language_token_mask] = self.language_expert_dense(context_layer[language_token_mask])
239
-
240
- attn_output = self.language_expert_dense(context_layer)
241
-
242
- if output_attentions:
243
- warnings.warn("output_attentions is not implemented.")
244
-
245
- return attn_output, None, past_key_value
246
-
247
-
248
- class CogVLMDecoderLayer(nn.Module):
249
- def __init__(self, config):
250
- super().__init__()
251
- self.hidden_size = config.hidden_size
252
- self.self_attn = VisionExpertAttention(config=config)
253
- self.mlp = VisionExpertMLP(config)
254
- self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
255
- self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
256
-
257
- def forward(
258
- self,
259
- hidden_states: torch.Tensor,
260
- token_type_ids: torch.LongTensor,
261
- position_ids: torch.LongTensor,
262
- attention_mask: Optional[torch.Tensor] = None,
263
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
264
- output_attentions: Optional[bool] = False,
265
- use_cache: Optional[bool] = False,
266
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
267
- residual = hidden_states
268
-
269
- hidden_states = self.input_layernorm(hidden_states)
270
-
271
- # Self Attention
272
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
273
- hidden_states=hidden_states,
274
- token_type_ids=token_type_ids,
275
- position_ids=position_ids,
276
- attention_mask=attention_mask,
277
- past_key_value=past_key_value,
278
- output_attentions=output_attentions,
279
- use_cache=use_cache,
280
- )
281
- hidden_states = residual + hidden_states
282
-
283
- # Fully Connected
284
- residual = hidden_states
285
- hidden_states = self.post_attention_layernorm(hidden_states)
286
- hidden_states = self.mlp(hidden_states, token_type_ids=token_type_ids)
287
- hidden_states = residual + hidden_states
288
-
289
- outputs = (hidden_states,)
290
-
291
- if output_attentions:
292
- outputs += (self_attn_weights,)
293
-
294
- if use_cache:
295
- outputs += (present_key_value,)
296
-
297
- return outputs # type: ignore
298
-
299
-
300
- class CogVLMPreTrainedModel(PreTrainedModel):
301
- config_class = CogVLMConfig
302
- base_model_prefix = "model"
303
- supports_gradient_checkpointing = False
304
- _no_split_modules = ["CogVLMDecoderLayer"]
305
- _skip_keys_device_placement = "past_key_values"
306
-
307
- def _init_weights(self, module):
308
- std = self.config.initializer_range
309
- if isinstance(module, nn.Linear):
310
- module.weight.data.normal_(mean=0.0, std=std)
311
- if module.bias is not None:
312
- module.bias.data.zero_()
313
- elif isinstance(module, nn.Embedding):
314
- module.weight.data.normal_(mean=0.0, std=std)
315
- if module.padding_idx is not None:
316
- module.weight.data[module.padding_idx].zero_()
317
-
318
-
319
- def is_empty(images_list: Optional[List[List[torch.Tensor]]]):
320
- if images_list is None or len(images_list) == 0:
321
- return True
322
- for image_list in images_list:
323
- if len(image_list):
324
- return False
325
- return True
326
-
327
-
328
- def build_position_ids(x: "torch.BoolTensor(B, L)",
329
- attention_mask: Optional["torch.BoolTensor(B, L)"] = None) -> "torch.LongTensor(B, L)":
330
- if attention_mask is not None:
331
- tmp = x.clone()
332
- tmp[~(attention_mask.bool())] = -1
333
- else:
334
- tmp = x.clone()
335
- # image boi eoi token as LANGUAGE_TOKEN_TYPE
336
- is_boi_eoi = torch.zeros_like(x, dtype=torch.bool)
337
- is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
338
- is_boi_eoi[:, 0] |= (tmp[:, 0] == VISION_TOKEN_TYPE)
339
- is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE)
340
- is_boi_eoi[:, -1] |= (tmp[:, -1] == VISION_TOKEN_TYPE)
341
- tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
342
- # final position ids
343
- y = torch.zeros_like(x, dtype=torch.long)
344
- y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | (
345
- (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE))
346
- y = y.cumsum(dim=-1)
347
- return y
348
-
349
-
350
- class CogVLMVideoModel(CogVLMPreTrainedModel):
351
- def __init__(self, config):
352
- super().__init__(config)
353
- self.padding_idx = 128002
354
- self.vocab_size = config.vocab_size
355
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
356
- self.layers = nn.ModuleList([CogVLMDecoderLayer(config) for _ in range(config.num_hidden_layers)])
357
- self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
358
-
359
- self.vision = EVA2CLIPModel(config)
360
-
361
- self.gradient_checkpointing = False
362
- # Initialize weights and apply final processing
363
- self.post_init()
364
-
365
- def encode_images(self, images: List[List[torch.Tensor]], ) -> torch.Tensor:
366
- images_list, images = images, []
367
-
368
- images = []
369
- for image_list in images_list:
370
- for image in image_list:
371
- images.append(image)
372
-
373
- # images = torch.stack(images) # video images is already stacked
374
- images_features = self.vision(images[0])
375
- return images_features
376
-
377
- def forward(
378
- self,
379
- input_ids: torch.LongTensor = None,
380
- images: List[List[torch.Tensor]] = None,
381
- token_type_ids: Optional[torch.LongTensor] = None,
382
- attention_mask: Optional[torch.Tensor] = None,
383
- position_ids: Optional[torch.LongTensor] = None,
384
- past_key_values: Optional[List[torch.FloatTensor]] = None,
385
- inputs_embeds: Optional[torch.FloatTensor] = None,
386
- use_cache: Optional[bool] = None,
387
- output_attentions: Optional[bool] = None,
388
- output_hidden_states: Optional[bool] = None,
389
- return_dict: Optional[bool] = None,
390
- ) -> Union[Tuple, BaseModelOutputWithPast]:
391
- """take care of image_encode, token_type_ids, position_ids and (attention_mask = None is fine)"""
392
-
393
- if past_key_values is not None:
394
- pass # generate mode with past_key_values. the image features are already mapped
395
- else:
396
- # not allow for inputs_embeds, because we want to process image feature
397
- assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
398
- if not is_empty(images): # multi-modality
399
- assert token_type_ids is not None, f"multi-modality requires `token_type_ids`!"
400
- assert len(input_ids) == len(images), f"{len(input_ids)} {len(images)}"
401
- inputs_embeds = self.embed_tokens(input_ids)
402
- images_features = self.encode_images(images)
403
- images_features = rearrange(images_features, 'b n d -> (b n) d')
404
- images_features = images_features.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
405
- inputs_embeds = inputs_embeds.index_put([token_type_ids == VISION_TOKEN_TYPE], images_features)
406
- else: # single-modality
407
- if token_type_ids is None:
408
- token_type_ids = torch.ones_like(input_ids, dtype=torch.long,
409
- device=input_ids.device) * LANGUAGE_TOKEN_TYPE
410
- assert not (token_type_ids == VISION_TOKEN_TYPE).any(), f"{(token_type_ids == VISION_TOKEN_TYPE).sum()}"
411
- inputs_embeds = self.embed_tokens(input_ids)
412
-
413
- if position_ids is None:
414
- position_ids = build_position_ids(token_type_ids, attention_mask)
415
- input_ids = None
416
- return self.llm_forward(
417
- input_ids=input_ids,
418
- token_type_ids=token_type_ids,
419
- attention_mask=attention_mask,
420
- position_ids=position_ids,
421
- past_key_values=past_key_values,
422
- inputs_embeds=inputs_embeds,
423
- use_cache=use_cache,
424
- output_attentions=output_attentions,
425
- output_hidden_states=output_hidden_states,
426
- return_dict=return_dict,
427
- )
428
-
429
- def llm_forward(
430
- self,
431
- input_ids: torch.LongTensor = None,
432
- token_type_ids: torch.LongTensor = None,
433
- attention_mask: Optional[torch.Tensor] = None,
434
- position_ids: Optional[torch.LongTensor] = None,
435
- past_key_values: Optional[List[torch.FloatTensor]] = None,
436
- inputs_embeds: Optional[torch.FloatTensor] = None,
437
- use_cache: Optional[bool] = None,
438
- output_attentions: Optional[bool] = None,
439
- output_hidden_states: Optional[bool] = None,
440
- return_dict: Optional[bool] = None,
441
- ) -> Union[Tuple, BaseModelOutputWithPast]:
442
- """largely copy from llama forward and adapt for cogvlm with `token_type_ids`"""
443
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
444
- output_hidden_states = (
445
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
446
- )
447
- use_cache = use_cache if use_cache is not None else self.config.use_cache
448
-
449
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
450
-
451
- # retrieve input_ids and inputs_embeds
452
- if input_ids is not None and inputs_embeds is not None:
453
- raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
454
- elif input_ids is not None:
455
- batch_size, seq_length = input_ids.shape
456
- elif inputs_embeds is not None:
457
- batch_size, seq_length, _ = inputs_embeds.shape
458
- else:
459
- raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
460
-
461
- seq_length_with_past = seq_length
462
- past_key_values_length = 0
463
-
464
- if past_key_values is not None:
465
- past_key_values_length = past_key_values[0][0].shape[2]
466
- seq_length_with_past = seq_length_with_past + past_key_values_length
467
-
468
- if position_ids is None:
469
- device = input_ids.device if input_ids is not None else inputs_embeds.device
470
- position_ids = torch.arange(
471
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
472
- )
473
- position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
474
- else:
475
- position_ids = position_ids.view(-1, seq_length).long()
476
-
477
- if inputs_embeds is None:
478
- inputs_embeds = self.embed_tokens(input_ids)
479
- # embed positions
480
- if attention_mask is None:
481
- attention_mask = torch.ones(
482
- (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
483
- )
484
- attention_mask = self._prepare_decoder_attention_mask(
485
- attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
486
- )
487
-
488
- hidden_states = inputs_embeds
489
-
490
- # decoder layers
491
- all_hidden_states = () if output_hidden_states else None
492
- all_self_attns = () if output_attentions else None
493
- next_decoder_cache = () if use_cache else None
494
-
495
- for idx, decoder_layer in enumerate(self.layers):
496
- if output_hidden_states:
497
- all_hidden_states += (hidden_states,)
498
-
499
- past_key_value = past_key_values[idx] if past_key_values is not None else None
500
- layer_outputs = decoder_layer(
501
- hidden_states,
502
- token_type_ids=token_type_ids,
503
- attention_mask=attention_mask,
504
- position_ids=position_ids,
505
- past_key_value=past_key_value,
506
- output_attentions=output_attentions,
507
- use_cache=use_cache,
508
- )
509
- hidden_states = layer_outputs[0]
510
-
511
- if use_cache:
512
- next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
513
-
514
- if output_attentions:
515
- all_self_attns += (layer_outputs[1],)
516
-
517
- hidden_states = self.norm(hidden_states)
518
-
519
- # add hidden states from the last decoder layer
520
- if output_hidden_states:
521
- all_hidden_states += (hidden_states,)
522
-
523
- next_cache = next_decoder_cache if use_cache else None
524
- if not return_dict:
525
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
526
- return BaseModelOutputWithPast(
527
- last_hidden_state=hidden_states,
528
- past_key_values=next_cache,
529
- hidden_states=all_hidden_states,
530
- attentions=all_self_attns,
531
- )
532
-
533
- def get_input_embeddings(self):
534
- return self.embed_tokens
535
-
536
- def set_input_embeddings(self, value):
537
- self.embed_tokens = value
538
-
539
- # noinspection PyMethodMayBeStatic
540
- # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
541
- def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
542
- # create causal mask
543
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
544
- combined_attention_mask = None
545
- if input_shape[-1] > 1:
546
- combined_attention_mask = _make_causal_mask(
547
- input_shape,
548
- inputs_embeds.dtype,
549
- device=inputs_embeds.device,
550
- past_key_values_length=past_key_values_length,
551
- )
552
-
553
- if attention_mask is not None:
554
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
555
- expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
556
- inputs_embeds.device
557
- )
558
- combined_attention_mask = (
559
- expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
560
- )
561
-
562
- return combined_attention_mask
563
-
564
-
565
- def _history_to_prompt(signal_type, history, query):
566
- if signal_type == 'base':
567
- return query
568
- elif signal_type == 'vqa':
569
- answer_format = 'Short answer:'
570
- elif signal_type == 'chat':
571
- answer_format = 'Answer:'
572
- else:
573
- assert False, f"Unknown signal type {signal_type}"
574
-
575
- prompt = ''
576
- for i, (old_query, response) in enumerate(history):
577
- prompt += 'Question: ' + old_query + " {} ".format(answer_format) + response + "\n"
578
- prompt += 'Question: {} {}'.format(query, answer_format)
579
- return prompt
580
-
581
-
582
- class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
583
- _auto_class = "AutoModelForCausalLM"
584
-
585
- def __init__(self, config):
586
- super().__init__(config)
587
- self.model = CogVLMVideoModel(config)
588
- self.vocab_size = config.vocab_size
589
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
590
- self.video_downsample = 1 # TODO: change this to config
591
-
592
- # Initialize weights and apply final processing
593
- self.post_init()
594
-
595
- def get_input_embeddings(self):
596
- return self.model.embed_tokens
597
-
598
- def set_input_embeddings(self, value):
599
- self.model.embed_tokens = value
600
-
601
- def get_output_embeddings(self):
602
- return self.lm_head
603
-
604
- def set_output_embeddings(self, new_embeddings):
605
- self.lm_head = new_embeddings
606
-
607
- def set_decoder(self, decoder):
608
- self.model = decoder
609
-
610
- def get_decoder(self):
611
- return self.model
612
-
613
- def forward(
614
- self,
615
- input_ids: torch.LongTensor = None,
616
- images: List[List[torch.Tensor]] = None,
617
- token_type_ids: Optional[torch.LongTensor] = None,
618
- attention_mask: Optional[torch.Tensor] = None,
619
- position_ids: Optional[torch.LongTensor] = None,
620
- past_key_values: Optional[List[torch.FloatTensor]] = None,
621
- inputs_embeds: Optional[torch.FloatTensor] = None,
622
- use_cache: Optional[bool] = None,
623
- output_attentions: Optional[bool] = None,
624
- output_hidden_states: Optional[bool] = None,
625
- return_dict: Optional[bool] = None,
626
- labels: Optional[torch.LongTensor] = None,
627
- ) -> Union[Tuple, CausalLMOutputWithPast]:
628
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
629
- output_hidden_states = (
630
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
631
- )
632
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
633
-
634
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
635
- outputs = self.model(
636
- input_ids=input_ids,
637
- images=images,
638
- token_type_ids=token_type_ids,
639
- attention_mask=attention_mask,
640
- position_ids=position_ids,
641
- past_key_values=past_key_values,
642
- inputs_embeds=inputs_embeds,
643
- use_cache=use_cache,
644
- output_attentions=output_attentions,
645
- output_hidden_states=output_hidden_states,
646
- return_dict=return_dict,
647
- )
648
-
649
- hidden_states = outputs[0]
650
- logits = self.lm_head(hidden_states)
651
- logits = logits.float()
652
-
653
- loss = None
654
- if labels is not None:
655
- # Shift so that tokens < n predict n
656
- shift_logits = logits[..., :-1, :].contiguous()
657
- shift_labels = labels[..., 1:].contiguous()
658
- # Flatten the tokens
659
- loss_fct = CrossEntropyLoss()
660
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
661
- shift_labels = shift_labels.view(-1)
662
- # Enable model parallelism
663
- shift_labels = shift_labels.to(shift_logits.device)
664
- loss = loss_fct(shift_logits, shift_labels)
665
-
666
- if not return_dict:
667
- output = (logits,) + outputs[1:]
668
- return (loss,) + output if loss is not None else output
669
-
670
- return CausalLMOutputWithPast(
671
- loss=loss,
672
- logits=logits,
673
- past_key_values=outputs.past_key_values,
674
- hidden_states=outputs.hidden_states,
675
- attentions=outputs.attentions,
676
- )
677
-
678
- def _prepare_attention_mask_for_generation(
679
- self,
680
- inputs: torch.Tensor,
681
- pad_token_id: Optional[int],
682
- eos_token_id: Optional[Union[int, List[int]]],
683
- ) -> torch.LongTensor:
684
- return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) # type: ignore
685
-
686
- def prepare_inputs_for_generation(
687
- self, input_ids, token_type_ids, images=None, past_key_values=None, attention_mask=None, inputs_embeds=None,
688
- **kwargs
689
- ):
690
- # build position_ids if needed
691
- position_ids = kwargs.get("position_ids", None)
692
- if position_ids is None:
693
- position_ids = build_position_ids(token_type_ids, attention_mask)
694
-
695
- if past_key_values:
696
- input_ids = input_ids[:, -1:]
697
- token_type_ids = token_type_ids[:, -1:]
698
- position_ids = position_ids[:, -1:]
699
-
700
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
701
- if inputs_embeds is not None and past_key_values is None:
702
- model_inputs = {"inputs_embeds": inputs_embeds}
703
- else:
704
- model_inputs = {"input_ids": input_ids}
705
-
706
- model_inputs.update(
707
- {
708
- "token_type_ids": token_type_ids,
709
- "images": images,
710
- "position_ids": position_ids,
711
- "past_key_values": past_key_values,
712
- "use_cache": kwargs.get("use_cache"),
713
- "attention_mask": attention_mask,
714
- }
715
- )
716
- return model_inputs
717
-
718
- def _update_model_kwargs_for_generation(
719
- self,
720
- outputs: "ModelOutput",
721
- model_kwargs: Dict[str, Any],
722
- is_encoder_decoder: bool = False,
723
- standardize_cache_format: bool = False,
724
- ) -> Dict[str, Any]:
725
- # update past_key_values
726
- cache_name, cache = self._extract_past_from_model_output(
727
- outputs, standardize_cache_format=standardize_cache_format
728
- )
729
- model_kwargs[cache_name] = cache
730
-
731
- if getattr(outputs, "state", None) is not None:
732
- model_kwargs["state"] = outputs.state
733
-
734
- # update token_type_ids with last value
735
- if "token_type_ids" in model_kwargs:
736
- token_type_ids = model_kwargs["token_type_ids"]
737
- new_token_type_ids = torch.ones(size=(token_type_ids.shape[0], 1), dtype=token_type_ids.dtype,
738
- device=token_type_ids.device) * LANGUAGE_TOKEN_TYPE
739
- model_kwargs["token_type_ids"] = torch.cat([token_type_ids, new_token_type_ids], dim=-1)
740
-
741
- if not is_encoder_decoder:
742
- # update attention mask
743
- if "attention_mask" in model_kwargs:
744
- attention_mask = model_kwargs["attention_mask"]
745
- model_kwargs["attention_mask"] = torch.cat(
746
- [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
747
- )
748
- else:
749
- # update decoder attention mask
750
- if "decoder_attention_mask" in model_kwargs:
751
- decoder_attention_mask = model_kwargs["decoder_attention_mask"]
752
- model_kwargs["decoder_attention_mask"] = torch.cat(
753
- [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
754
- dim=-1,
755
- )
756
-
757
- return model_kwargs
758
-
759
- def _reorder_cache(self, past_key_values, beam_idx):
760
- reordered_past = ()
761
- for layer_past in past_key_values:
762
- reordered_past += (
763
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
764
- )
765
- return reordered_past
766
-
767
- def build_conversation_input_ids(
768
- self,
769
- tokenizer: "PreTrainedTokenizer",
770
- *,
771
- query: str,
772
- history: Optional[List[Tuple[str, str]]] = None,
773
- images: Optional[List["PIL.Image"]] = None,
774
- template_version: Optional[Literal["base", "chat", "vqa"]] = None,
775
- answer: str = None,
776
- ):
777
- image_size: int = self.config.vision_config['image_size']
778
- template_version = template_version or self.config.template_version
779
- assert images is None or len(images) <= 1, f"not support multi images by now."
780
- history = history or []
781
- text = _history_to_prompt(template_version, history, query)
782
- input_ids = [tokenizer.bos_token_id]
783
- token_type_ids = [LANGUAGE_TOKEN_TYPE]
784
- add_time_indices = True if template_version == 'chat' else False
785
- if images is not None and len(images) == 1:
786
- # vision
787
- transform = transforms.Compose(
788
- [
789
- # UniformTemporalSubsample(num_frames),
790
- Lambda(lambda x: x / 255.0),
791
- NormalizeVideo(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),
792
- ShortSideScale(size=image_size),
793
- CenterCropVideo(image_size),
794
- # RandomHorizontalFlipVideo(p=0.5),
795
- ]
796
- )
797
- images = [transform(images[0]).transpose(0, 1)] # (T, C, H, W)
798
- num_eois = len(images[0])
799
- tokenizer.pad_token_id = 128002
800
- if not add_time_indices:
801
- vision_token_num = (64 + 2) * num_eois
802
- input_ids += [tokenizer.pad_token_id] * vision_token_num # add special token
803
- token_type_ids += [VISION_TOKEN_TYPE] * vision_token_num
804
- else:
805
- video_ids, video_type_ids = [], []
806
- sing_vision_token_num = (64 + 2)
807
- for _time_idx in range(num_eois):
808
- video_ids += [tokenizer.pad_token_id] * sing_vision_token_num
809
- video_type_ids += [VISION_TOKEN_TYPE] * sing_vision_token_num
810
- # add time indices
811
- time_indices = tokenizer.encode(str(_time_idx), add_special_tokens=False)
812
- video_ids += time_indices
813
- video_type_ids += [LANGUAGE_TOKEN_TYPE] * len(time_indices)
814
- # llama3 adapt for cogvlm
815
- input_ids += video_ids
816
- token_type_ids += video_type_ids
817
-
818
- text_ids = tokenizer.encode(text, add_special_tokens=False)
819
-
820
- if answer is not None:
821
- answer_ids = tokenizer.encode(answer, add_special_tokens=False)
822
- answer_ids += [tokenizer.eos_token_id]
823
- text_ids += answer_ids
824
-
825
- input_ids += text_ids
826
- token_type_ids += [LANGUAGE_TOKEN_TYPE] * len(text_ids)
827
- attention_mask = [1] * len(input_ids)
828
- if answer is not None:
829
- labels = [-100 for _ in range(len(input_ids) - len(answer_ids))] + answer_ids
830
- labels = torch.tensor(labels, dtype=torch.long)
831
- else:
832
- labels = None
833
-
834
- return {
835
- 'input_ids': torch.tensor(input_ids, dtype=torch.long),
836
- 'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
837
- 'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
838
- 'images': images,
839
- 'labels': labels,
840
- }
 
1
+ """largely copy from llama and adapt for cogvlm"""
2
+ import warnings
3
+ from typing import TYPE_CHECKING, Optional, Tuple, List, Union, Literal, Dict, Any
4
+
5
+ import math
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import CrossEntropyLoss
9
+ from torchvision import transforms
10
+ from einops import rearrange
11
+ from transformers import PreTrainedModel, PreTrainedTokenizer
12
+ from transformers.utils.logging import get_logger
13
+ from transformers.activations import ACT2FN
14
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
15
+ from torchvision.transforms import Lambda
16
+ from torchvision.transforms._transforms_video import NormalizeVideo, CenterCropVideo
17
+ from pytorchvideo.transforms import ShortSideScale
18
+ from .configuration_cogvlm import CogVLMConfig
19
+ from .util import FastRotaryEmbedding
20
+ from .visual import EVA2CLIPModel
21
+
22
+ if TYPE_CHECKING:
23
+ from transformers.utils import ModelOutput
24
+
25
+ logger = get_logger(__name__)
26
+
27
+ LANGUAGE_TOKEN_TYPE = 0
28
+ VISION_TOKEN_TYPE = 1
29
+
30
+
31
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
32
+ def _make_causal_mask(
33
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
34
+ ):
35
+ """
36
+ Make causal mask used for bi-directional self-attention.
37
+ """
38
+ bsz, tgt_len = input_ids_shape
39
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
40
+ mask_cond = torch.arange(mask.size(-1), device=device)
41
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
42
+ mask = mask.to(dtype)
43
+
44
+ if past_key_values_length > 0:
45
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
46
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
47
+
48
+
49
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
50
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
51
+ """
52
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
53
+ """
54
+ bsz, src_len = mask.size()
55
+ tgt_len = tgt_len if tgt_len is not None else src_len
56
+
57
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
58
+
59
+ inverted_mask = 1.0 - expanded_mask
60
+
61
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
62
+
63
+
64
+ class RMSNorm(nn.Module):
65
+ def __init__(self, hidden_size, eps=1e-5):
66
+ super().__init__()
67
+ self.weight = nn.Parameter(torch.ones(hidden_size))
68
+ self.variance_epsilon = eps
69
+
70
+ def forward(self, hidden_states):
71
+ input_dtype = hidden_states.dtype
72
+ hidden_states = hidden_states.to(torch.float32)
73
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
74
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
75
+ return (self.weight * hidden_states).to(input_dtype)
76
+
77
+
78
+ class MLP(nn.Module):
79
+ def __init__(self, config):
80
+ super().__init__()
81
+ self.hidden_size = config.hidden_size
82
+ self.intermediate_size = config.intermediate_size
83
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
84
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
85
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
86
+ self.act_fn = ACT2FN[config.hidden_act]
87
+
88
+ def forward(self, x):
89
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
90
+ return down_proj
91
+
92
+
93
+ def get_expert_mask(token_type_ids: "torch.LongTensor(B, L)") -> "[torch.BoolTensor(B, L), torch.BoolTensor(B, L)]":
94
+ vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool)
95
+ vision_token_mask[:, :-1] = (token_type_ids[:, :-1] == VISION_TOKEN_TYPE) & (
96
+ token_type_ids[:, 1:] == VISION_TOKEN_TYPE)
97
+ language_token_mask = ~vision_token_mask
98
+ return vision_token_mask, language_token_mask
99
+
100
+
101
+ class VisionExpertMLP(nn.Module):
102
+ def __init__(self, config):
103
+ super().__init__()
104
+ self.language_mlp = MLP(config)
105
+ # self.vision_mlp = MLP(config)
106
+
107
+ def forward(self, hidden_states: "torch.Tensor(B, L, D)", token_type_ids: "torch.LongTensor(B, L)"):
108
+ # output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device)
109
+ # vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
110
+ # output[vision_token_mask] = self.vision_mlp(hidden_states[vision_token_mask])
111
+ # output[language_token_mask] = self.language_mlp(hidden_states[language_token_mask])
112
+
113
+ output = self.language_mlp(hidden_states)
114
+ return output
115
+
116
+
117
+ def attention_fn(
118
+ query_layer: "torch.tensor(B, H, L, HD)",
119
+ key_layer: "torch.tensor(B, H, L, HD)",
120
+ value_layer: "torch.tensor(B, H, L, HD)",
121
+ attention_mask: "torch.tensor(B, H, L, HD)",
122
+ *,
123
+ scaling_attention_score: bool = True,
124
+ attention_dropout: nn.Module = None
125
+ ):
126
+ attention_mask_bool = (attention_mask == 0)
127
+ is_low_triangle = (attention_mask_bool == torch.ones_like(attention_mask_bool, dtype=torch.float).tril()).all()
128
+ is_full = (attention_mask_bool > 0).all()
129
+ if not (int(torch.__version__.split('.')[0]) >= 2):
130
+ warnings.warn("It's recommended to use torch2.0 or higher.")
131
+ if int(torch.__version__.split('.')[0]) >= 2 and scaling_attention_score and (is_full or is_low_triangle):
132
+ dropout_p = 0. if attention_dropout is None or not attention_dropout.training else attention_dropout.p
133
+ return torch.nn.functional.scaled_dot_product_attention(
134
+ query_layer, key_layer, value_layer,
135
+ attn_mask=None,
136
+ dropout_p=dropout_p,
137
+ is_causal=not is_full
138
+ )
139
+ else:
140
+ if scaling_attention_score:
141
+ query_layer = query_layer / math.sqrt(query_layer.shape[-1])
142
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
143
+ attention_scores = attention_scores + attention_mask
144
+ attention_scores = nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).to(query_layer.dtype)
145
+ if attention_dropout is not None:
146
+ attention_scores = attention_dropout(attention_scores)
147
+ context_layer = torch.matmul(attention_scores, value_layer)
148
+ return context_layer
149
+
150
+
151
+ class VisionExpertAttention(nn.Module):
152
+ def __init__(self, config):
153
+ super().__init__()
154
+ self.config = config
155
+ self.hidden_size = config.hidden_size
156
+ self.num_attention_heads = config.num_attention_heads
157
+ self.num_multi_query_heads = config.num_multi_query_heads
158
+ self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads
159
+ self.stride = [self.num_attention_heads, self.num_multi_query_heads, self.num_multi_query_heads]
160
+ self.qkv_size = self.hidden_size + self.hidden_size_per_attention_head * self.num_multi_query_heads * 2
161
+ self.head_dim = self.hidden_size // self.num_attention_heads
162
+ self.max_position_embeddings = config.max_position_embeddings
163
+ self.rotary_emb = FastRotaryEmbedding(dim=self.head_dim, pos_idx_in_fp32=False, base=500000)
164
+ # self.vision_expert_query_key_value = nn.Linear(self.hidden_size, self.qkv_size, bias=True)
165
+ # self.vision_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
166
+ self.language_expert_query_key_value = nn.Linear(self.hidden_size, self.qkv_size, bias=False)
167
+ self.language_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
168
+
169
+ def _transpose_for_scores(self, tensor):
170
+ """Transpose a 3D tensor [B, L, H*HD] into a 4D tensor with size [B H L HD]."""
171
+ new_tensor_shape = tensor.size()[:-1] + \
172
+ (-1, # flexible for multi-query
173
+ self.hidden_size_per_attention_head)
174
+ tensor = tensor.view(*new_tensor_shape)
175
+ return tensor.permute(0, 2, 1, 3)
176
+
177
+ def forward(
178
+ self,
179
+ hidden_states: torch.Tensor,
180
+ token_type_ids: torch.LongTensor,
181
+ position_ids: torch.LongTensor,
182
+ attention_mask: Optional[torch.Tensor] = None,
183
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
184
+ output_attentions: bool = False,
185
+ use_cache: bool = False,
186
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
187
+ bsz, q_len, _ = hidden_states.size()
188
+ # vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
189
+
190
+ shape = list(hidden_states.shape)
191
+ shape[-1] = self.qkv_size
192
+ # mixed_raw_layer = torch.empty(shape, dtype=hidden_states.dtype, device=hidden_states.device)
193
+ # mixed_raw_layer[vision_token_mask] = self.vision_expert_query_key_value(hidden_states[vision_token_mask])
194
+ # mixed_raw_layer[language_token_mask] = self.language_expert_query_key_value(hidden_states[language_token_mask])
195
+ mixed_raw_layer = self.language_expert_query_key_value(hidden_states)
196
+
197
+ # query_states, key_states, value_states = torch.split(mixed_raw_layer, self.hidden_size, dim=-1)
198
+ factor = mixed_raw_layer.size()[-1] // sum(self.stride)
199
+ query_states, key_states, value_states = torch.split(mixed_raw_layer, [factor * x for x in self.stride], dim=-1)
200
+
201
+ query_states = self._transpose_for_scores(query_states) # B, H, L, HD
202
+ key_states = self._transpose_for_scores(key_states) # B, H, L, HD
203
+ value_states = self._transpose_for_scores(value_states) # B, H, L, HD
204
+
205
+ kv_seq_len = key_states.shape[-2]
206
+ if past_key_value is not None:
207
+ kv_seq_len += past_key_value[0].shape[-2]
208
+
209
+ query_states, key_states = self.rotary_emb(query_states, key_states, position_ids=position_ids,
210
+ max_seqlen=position_ids.max() + 1)
211
+
212
+ if past_key_value is not None:
213
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
214
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
215
+
216
+ past_key_value = (key_states, value_states) if use_cache else None
217
+
218
+ key_states = key_states.unsqueeze(2).expand(-1, -1, self.num_attention_heads // self.num_multi_query_heads, -1,
219
+ -1).contiguous().view(
220
+ bsz, self.num_attention_heads, *key_states.shape[2:])
221
+ value_states = value_states.unsqueeze(2).expand(-1, -1, self.num_attention_heads // self.num_multi_query_heads,
222
+ -1,
223
+ -1).contiguous().view(bsz, self.num_attention_heads,
224
+ *value_states.shape[2:])
225
+
226
+ context_layer = attention_fn(
227
+ query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
228
+ scaling_attention_score=True, attention_dropout=None)
229
+ if context_layer.size() != (bsz, self.num_attention_heads, q_len, self.head_dim):
230
+ raise ValueError(
231
+ f"`attn_output` should be of size {(bsz, self.num_attention_heads, q_len, self.head_dim)}, but is"
232
+ f" {context_layer.size()}"
233
+ )
234
+ context_layer = context_layer.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
235
+
236
+ # attn_output = torch.empty(context_layer.shape, dtype=hidden_states.dtype, device=hidden_states.device)
237
+ # attn_output[vision_token_mask] = self.vision_expert_dense(context_layer[vision_token_mask])
238
+ # attn_output[language_token_mask] = self.language_expert_dense(context_layer[language_token_mask])
239
+
240
+ attn_output = self.language_expert_dense(context_layer)
241
+
242
+ if output_attentions:
243
+ warnings.warn("output_attentions is not implemented.")
244
+
245
+ return attn_output, None, past_key_value
246
+
247
+
248
+ class CogVLMDecoderLayer(nn.Module):
249
+ def __init__(self, config):
250
+ super().__init__()
251
+ self.hidden_size = config.hidden_size
252
+ self.self_attn = VisionExpertAttention(config=config)
253
+ self.mlp = VisionExpertMLP(config)
254
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
255
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
256
+
257
+ def forward(
258
+ self,
259
+ hidden_states: torch.Tensor,
260
+ token_type_ids: torch.LongTensor,
261
+ position_ids: torch.LongTensor,
262
+ attention_mask: Optional[torch.Tensor] = None,
263
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
264
+ output_attentions: Optional[bool] = False,
265
+ use_cache: Optional[bool] = False,
266
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
267
+ residual = hidden_states
268
+
269
+ hidden_states = self.input_layernorm(hidden_states)
270
+
271
+ # Self Attention
272
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
273
+ hidden_states=hidden_states,
274
+ token_type_ids=token_type_ids,
275
+ position_ids=position_ids,
276
+ attention_mask=attention_mask,
277
+ past_key_value=past_key_value,
278
+ output_attentions=output_attentions,
279
+ use_cache=use_cache,
280
+ )
281
+ hidden_states = residual + hidden_states
282
+
283
+ # Fully Connected
284
+ residual = hidden_states
285
+ hidden_states = self.post_attention_layernorm(hidden_states)
286
+ hidden_states = self.mlp(hidden_states, token_type_ids=token_type_ids)
287
+ hidden_states = residual + hidden_states
288
+
289
+ outputs = (hidden_states,)
290
+
291
+ if output_attentions:
292
+ outputs += (self_attn_weights,)
293
+
294
+ if use_cache:
295
+ outputs += (present_key_value,)
296
+
297
+ return outputs # type: ignore
298
+
299
+
300
+ class CogVLMPreTrainedModel(PreTrainedModel):
301
+ config_class = CogVLMConfig
302
+ base_model_prefix = "model"
303
+ supports_gradient_checkpointing = False
304
+ _no_split_modules = ["CogVLMDecoderLayer"]
305
+ _skip_keys_device_placement = "past_key_values"
306
+
307
+ def _init_weights(self, module):
308
+ std = self.config.initializer_range
309
+ if isinstance(module, nn.Linear):
310
+ module.weight.data.normal_(mean=0.0, std=std)
311
+ if module.bias is not None:
312
+ module.bias.data.zero_()
313
+ elif isinstance(module, nn.Embedding):
314
+ module.weight.data.normal_(mean=0.0, std=std)
315
+ if module.padding_idx is not None:
316
+ module.weight.data[module.padding_idx].zero_()
317
+
318
+
319
+ def is_empty(images_list: Optional[List[List[torch.Tensor]]]):
320
+ if images_list is None or len(images_list) == 0:
321
+ return True
322
+ for image_list in images_list:
323
+ if len(image_list):
324
+ return False
325
+ return True
326
+
327
+
328
+ def build_position_ids(x: "torch.BoolTensor(B, L)",
329
+ attention_mask: Optional["torch.BoolTensor(B, L)"] = None) -> "torch.LongTensor(B, L)":
330
+ if attention_mask is not None:
331
+ tmp = x.clone()
332
+ tmp[~(attention_mask.bool())] = -1
333
+ else:
334
+ tmp = x.clone()
335
+ # image boi eoi token as LANGUAGE_TOKEN_TYPE
336
+ is_boi_eoi = torch.zeros_like(x, dtype=torch.bool)
337
+ is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
338
+ is_boi_eoi[:, 0] |= (tmp[:, 0] == VISION_TOKEN_TYPE)
339
+ is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE)
340
+ is_boi_eoi[:, -1] |= (tmp[:, -1] == VISION_TOKEN_TYPE)
341
+ tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
342
+ # final position ids
343
+ y = torch.zeros_like(x, dtype=torch.long)
344
+ y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | (
345
+ (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE))
346
+ y = y.cumsum(dim=-1)
347
+ return y
348
+
349
+
350
+ class CogVLMVideoModel(CogVLMPreTrainedModel):
351
+ def __init__(self, config):
352
+ super().__init__(config)
353
+ self.padding_idx = 128002
354
+ self.vocab_size = config.vocab_size
355
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
356
+ self.layers = nn.ModuleList([CogVLMDecoderLayer(config) for _ in range(config.num_hidden_layers)])
357
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
358
+
359
+ self.vision = EVA2CLIPModel(config)
360
+
361
+ self.gradient_checkpointing = False
362
+ # Initialize weights and apply final processing
363
+ self.post_init()
364
+
365
+ def encode_images(self, images: List[List[torch.Tensor]], ) -> torch.Tensor:
366
+ images_list, images = images, []
367
+
368
+ images = []
369
+ for image_list in images_list:
370
+ for image in image_list:
371
+ images.append(image)
372
+
373
+ # images = torch.stack(images) # video images is already stacked
374
+ images_features = self.vision(images[0])
375
+ return images_features
376
+
377
+ def forward(
378
+ self,
379
+ input_ids: torch.LongTensor = None,
380
+ images: List[List[torch.Tensor]] = None,
381
+ token_type_ids: Optional[torch.LongTensor] = None,
382
+ attention_mask: Optional[torch.Tensor] = None,
383
+ position_ids: Optional[torch.LongTensor] = None,
384
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
385
+ inputs_embeds: Optional[torch.FloatTensor] = None,
386
+ use_cache: Optional[bool] = None,
387
+ output_attentions: Optional[bool] = None,
388
+ output_hidden_states: Optional[bool] = None,
389
+ return_dict: Optional[bool] = None,
390
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
391
+ """take care of image_encode, token_type_ids, position_ids and (attention_mask = None is fine)"""
392
+
393
+ if past_key_values is not None:
394
+ pass # generate mode with past_key_values. the image features are already mapped
395
+ else:
396
+ # not allow for inputs_embeds, because we want to process image feature
397
+ assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
398
+ if not is_empty(images): # multi-modality
399
+ assert token_type_ids is not None, f"multi-modality requires `token_type_ids`!"
400
+ assert len(input_ids) == len(images), f"{len(input_ids)} {len(images)}"
401
+ inputs_embeds = self.embed_tokens(input_ids)
402
+ images_features = self.encode_images(images)
403
+ images_features = rearrange(images_features, 'b n d -> (b n) d')
404
+ images_features = images_features.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
405
+ inputs_embeds = inputs_embeds.index_put([token_type_ids == VISION_TOKEN_TYPE], images_features)
406
+ else: # single-modality
407
+ if token_type_ids is None:
408
+ token_type_ids = torch.ones_like(input_ids, dtype=torch.long,
409
+ device=input_ids.device) * LANGUAGE_TOKEN_TYPE
410
+ assert not (token_type_ids == VISION_TOKEN_TYPE).any(), f"{(token_type_ids == VISION_TOKEN_TYPE).sum()}"
411
+ inputs_embeds = self.embed_tokens(input_ids)
412
+
413
+ if position_ids is None:
414
+ position_ids = build_position_ids(token_type_ids, attention_mask)
415
+ input_ids = None
416
+ return self.llm_forward(
417
+ input_ids=input_ids,
418
+ token_type_ids=token_type_ids,
419
+ attention_mask=attention_mask,
420
+ position_ids=position_ids,
421
+ past_key_values=past_key_values,
422
+ inputs_embeds=inputs_embeds,
423
+ use_cache=use_cache,
424
+ output_attentions=output_attentions,
425
+ output_hidden_states=output_hidden_states,
426
+ return_dict=return_dict,
427
+ )
428
+
429
+    def llm_forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        token_type_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        """largely copy from llama forward and adapt for cogvlm with `token_type_ids`"""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            batch_size, seq_length = input_ids.shape
+        elif inputs_embeds is not None:
+            batch_size, seq_length, _ = inputs_embeds.shape
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        seq_length_with_past = seq_length
+        past_key_values_length = 0
+
+        if past_key_values is not None:
+            past_key_values_length = past_key_values[0][0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
+
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+        # embed positions
+        if attention_mask is None:
+            attention_mask = torch.ones(
+                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+            )
+        attention_mask = self._prepare_decoder_attention_mask(
+            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+        )
+
+        hidden_states = inputs_embeds
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = () if use_cache else None
+
+        for idx, decoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+            layer_outputs = decoder_layer(
+                hidden_states,
+                token_type_ids=token_type_ids,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+            )
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if not return_dict:
+            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.embed_tokens = value
+
+    # noinspection PyMethodMayBeStatic
+    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
+    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
+        # create causal mask
+        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+        combined_attention_mask = None
+        if input_shape[-1] > 1:
+            combined_attention_mask = _make_causal_mask(
+                input_shape,
+                inputs_embeds.dtype,
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
+            )
+
+        if attention_mask is not None:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+                inputs_embeds.device
+            )
+            combined_attention_mask = (
+                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+            )
+
+        return combined_attention_mask
+
+
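As a quick aside (not part of the model file), the mask returned by `_prepare_decoder_attention_mask` is additive: allowed positions contribute 0 and blocked positions a very large negative value, combining a causal term with an expanded padding term. A small standalone sketch under those same conventions:

import torch

bsz, tgt_len = 1, 4
dtype = torch.float32
min_val = torch.finfo(dtype).min

# causal term: upper triangle blocked, lower triangle allowed
causal = torch.full((tgt_len, tgt_len), min_val)
causal = causal.masked_fill(torch.tril(torch.ones(tgt_len, tgt_len, dtype=torch.bool)), 0.0)

# padding term: the last position is padding and gets blocked everywhere
padding = torch.tensor([[1, 1, 1, 0]], dtype=dtype)
expanded = (1.0 - padding)[:, None, None, :] * min_val   # [bsz, 1, 1, src_len]

combined = causal[None, None, :, :] + expanded           # [bsz, 1, tgt_len, src_len]
print(combined[0, 0])

Entries can sum below `finfo.min` (printing as `-inf`) when both terms block a position; the softmax inside the attention layers handles that as intended.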
+def _history_to_prompt(signal_type, history, query):
+    if signal_type == 'base':
+        return query
+    elif signal_type == 'vqa':
+        answer_format = 'Short answer:'
+    elif signal_type == 'chat':
+        answer_format = 'Answer:'
+    else:
+        assert False, f"Unknown signal type {signal_type}"
+
+    prompt = ''
+    for i, (old_query, response) in enumerate(history):
+        prompt += 'Question: ' + old_query + " {} ".format(answer_format) + response + "\n"
+    prompt += 'Question: {} {}'.format(query, answer_format)
+    return prompt
+
+
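For reference (not part of the model file), a quick sketch of the string the 'chat' template produces; it assumes `_history_to_prompt` from directly above is in scope:

history = [("What is in the video?", "A cat chasing a laser pointer.")]
query = "What happens at the end?"

prompt = _history_to_prompt("chat", history, query)
# prompt is:
# "Question: What is in the video? Answer: A cat chasing a laser pointer.\n"
# "Question: What happens at the end? Answer:"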
+class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
+    _auto_class = "AutoModelForCausalLM"
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = CogVLMVideoModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.video_downsample = 1  # TODO: change this to config
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        images: List[List[torch.Tensor]] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            images=images,
+            token_type_ids=token_type_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+        logits = logits.float()
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
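A standalone sketch (not part of the model file) of the shift-by-one language-modeling loss computed in `forward` when `labels` are given: the logits at position t are scored against the token at position t+1, and positions labeled -100 are ignored by `CrossEntropyLoss`:

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 11
logits = torch.randn(1, 5, vocab_size)            # [batch, seq, vocab]
labels = torch.tensor([[-100, -100, 9, 2, 5]])    # prompt positions masked with -100

shift_logits = logits[..., :-1, :].contiguous()   # drop the last position
shift_labels = labels[..., 1:].contiguous()       # drop the first token
loss = CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
print(loss.item())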
+    def _prepare_attention_mask_for_generation(
+        self,
+        inputs: torch.Tensor,
+        pad_token_id: Optional[int],
+        eos_token_id: Optional[Union[int, List[int]]],
+    ) -> torch.LongTensor:
+        return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)  # type: ignore
+
+    def prepare_inputs_for_generation(
+        self, input_ids, token_type_ids, images=None, past_key_values=None, attention_mask=None, inputs_embeds=None,
+        **kwargs
+    ):
+        # build position_ids if needed
+        position_ids = kwargs.get("position_ids", None)
+        if position_ids is None:
+            position_ids = build_position_ids(token_type_ids, attention_mask)
+
+        if past_key_values:
+            input_ids = input_ids[:, -1:]
+            token_type_ids = token_type_ids[:, -1:]
+            position_ids = position_ids[:, -1:]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "token_type_ids": token_type_ids,
+                "images": images,
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs
+
+    def _update_model_kwargs_for_generation(
+        self,
+        outputs: "ModelOutput",
+        model_kwargs: Dict[str, Any],
+        is_encoder_decoder: bool = False,
+        standardize_cache_format: bool = False,
+    ) -> Dict[str, Any]:
+        try:
+            cache_name, cache = super()._extract_past_from_model_output(outputs)
+        except AttributeError:
+            past_key_values = outputs.past_key_values if "past_key_values" in outputs else None
+            cache_name, cache = "past_key_values", past_key_values
+        # write the extracted cache back so it reaches the next generation step
+        model_kwargs[cache_name] = cache
+
+        if getattr(outputs, "state", None) is not None:
+            model_kwargs["state"] = outputs.state
+
+        # update token_type_ids with last value
+        if "token_type_ids" in model_kwargs:
+            token_type_ids = model_kwargs["token_type_ids"]
+            new_token_type_ids = torch.ones(size=(token_type_ids.shape[0], 1), dtype=token_type_ids.dtype,
+                                            device=token_type_ids.device) * LANGUAGE_TOKEN_TYPE
+            model_kwargs["token_type_ids"] = torch.cat([token_type_ids, new_token_type_ids], dim=-1)
+
+        if not is_encoder_decoder:
+            # update attention mask
+            if "attention_mask" in model_kwargs:
+                attention_mask = model_kwargs["attention_mask"]
+                model_kwargs["attention_mask"] = torch.cat(
+                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+                )
+        else:
+            # update decoder attention mask
+            if "decoder_attention_mask" in model_kwargs:
+                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+                model_kwargs["decoder_attention_mask"] = torch.cat(
+                    [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
+                    dim=-1,
+                )
+
+        return model_kwargs
+
+    def _reorder_cache(self, past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
+        return reordered_past
+
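To make the per-step bookkeeping in `_update_model_kwargs_for_generation` concrete, here is a standalone sketch (not part of the model file): every generated token appends one LANGUAGE-type entry to `token_type_ids` and one entry to `attention_mask`:

import torch

LANGUAGE_TOKEN_TYPE = 0
token_type_ids = torch.tensor([[1, 1, 0, 0]])        # vision, vision, text, text
attention_mask = torch.ones(1, 4, dtype=torch.long)

# append one language-token slot for the token just generated
new_type = torch.ones((token_type_ids.shape[0], 1), dtype=token_type_ids.dtype) * LANGUAGE_TOKEN_TYPE
token_type_ids = torch.cat([token_type_ids, new_type], dim=-1)
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)

print(token_type_ids)        # tensor([[1, 1, 0, 0, 0]])
print(attention_mask.shape)  # torch.Size([1, 5])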
+    def build_conversation_input_ids(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        *,
+        query: str,
+        history: Optional[List[Tuple[str, str]]] = None,
+        images: Optional[List["PIL.Image"]] = None,
+        template_version: Optional[Literal["base", "chat", "vqa"]] = None,
+        answer: str = None,
+    ):
+        image_size: int = self.config.vision_config['image_size']
+        template_version = template_version or self.config.template_version
+        assert images is None or len(images) <= 1, "multiple images are not supported for now."
+        history = history or []
+        text = _history_to_prompt(template_version, history, query)
+        input_ids = [tokenizer.bos_token_id]
+        token_type_ids = [LANGUAGE_TOKEN_TYPE]
+        add_time_indices = True if template_version == 'chat' else False
+        if images is not None and len(images) == 1:
+            # vision
+            transform = transforms.Compose(
+                [
+                    # UniformTemporalSubsample(num_frames),
+                    Lambda(lambda x: x / 255.0),
+                    NormalizeVideo(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),
+                    ShortSideScale(size=image_size),
+                    CenterCropVideo(image_size),
+                    # RandomHorizontalFlipVideo(p=0.5),
+                ]
+            )
+            images = [transform(images[0]).transpose(0, 1)]  # (T, C, H, W)
+            num_eois = len(images[0])
+            tokenizer.pad_token_id = 128002
+            if not add_time_indices:
+                vision_token_num = (64 + 2) * num_eois
+                input_ids += [tokenizer.pad_token_id] * vision_token_num  # add special tokens
+                token_type_ids += [VISION_TOKEN_TYPE] * vision_token_num
+            else:
+                video_ids, video_type_ids = [], []
+                sing_vision_token_num = (64 + 2)
+                for _time_idx in range(num_eois):
+                    video_ids += [tokenizer.pad_token_id] * sing_vision_token_num
+                    video_type_ids += [VISION_TOKEN_TYPE] * sing_vision_token_num
+                    # add time indices
+                    time_indices = tokenizer.encode(str(_time_idx), add_special_tokens=False)
+                    video_ids += time_indices
+                    video_type_ids += [LANGUAGE_TOKEN_TYPE] * len(time_indices)
+                # llama3 adapt for cogvlm
+                input_ids += video_ids
+                token_type_ids += video_type_ids
+
+        text_ids = tokenizer.encode(text, add_special_tokens=False)
+
+        if answer is not None:
+            answer_ids = tokenizer.encode(answer, add_special_tokens=False)
+            answer_ids += [tokenizer.eos_token_id]
+            text_ids += answer_ids
+
+        input_ids += text_ids
+        token_type_ids += [LANGUAGE_TOKEN_TYPE] * len(text_ids)
+        attention_mask = [1] * len(input_ids)
+        if answer is not None:
+            labels = [-100 for _ in range(len(input_ids) - len(answer_ids))] + answer_ids
+            labels = torch.tensor(labels, dtype=torch.long)
+        else:
+            labels = None
+
+        return {
+            'input_ids': torch.tensor(input_ids, dtype=torch.long),
+            'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
+            'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
+            'images': images,
+            'labels': labels,
+        }
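Finally, a hedged end-to-end usage sketch (not part of the model file). The checkpoint id, the random stand-in for sampled video frames, and the dtype/device choices are assumptions for illustration only; real code would sample frames from a video with a decoder such as decord before handing them to `build_conversation_input_ids`.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_PATH = "THUDM/cogvlm2-video-llama3-chat"   # assumed checkpoint id
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH, torch_dtype=torch.bfloat16, trust_remote_code=True
).eval().to("cuda")

# `video` is assumed to be a uint8 tensor of shape (C, T, H, W), i.e. already-sampled frames.
video = torch.randint(0, 256, (3, 24, 336, 336), dtype=torch.uint8)

inputs = model.build_conversation_input_ids(
    tokenizer, query="Describe this video.", images=[video], template_version="chat"
)
inputs = {
    "input_ids": inputs["input_ids"].unsqueeze(0).to("cuda"),
    "token_type_ids": inputs["token_type_ids"].unsqueeze(0).to("cuda"),
    "attention_mask": inputs["attention_mask"].unsqueeze(0).to("cuda"),
    "images": [[inputs["images"][0].to("cuda", dtype=torch.bfloat16)]],
}
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(output[0, inputs["input_ids"].shape[1]:]))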