khang119966 commited on
Commit
e306fad
·
verified ·
1 Parent(s): bb9dec3

Upload 3 files

Browse files
Files changed (3) hide show
  1. conversation.py +395 -0
  2. modeling_intern_vit.py +430 -0
  3. modeling_internvl_chat.py +543 -0
conversation.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Conversation prompt templates.
3
+
4
+ We kindly request that you import fastchat instead of copying this file if you wish to use it.
5
+ If you have changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates.
6
+
7
+ Modified from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
8
+ """
9
+
10
+ import dataclasses
11
+ from enum import IntEnum, auto
12
+ from typing import Dict, List, Tuple, Union
13
+
14
+
15
class SeparatorStyle(IntEnum):
    """Identifiers for the prompt layouts understood by ``Conversation.get_prompt``.

    The numeric values are consecutive integers starting at 1, in declaration
    order.
    """

    ADD_COLON_SINGLE = 1
    ADD_COLON_TWO = 2
    ADD_COLON_SPACE_SINGLE = 3
    NO_COLON_SINGLE = 4
    NO_COLON_TWO = 5
    ADD_NEW_LINE_SINGLE = 6
    LLAMA2 = 7
    CHATGLM = 8
    CHATML = 9
    CHATINTERN = 10
    DOLLY = 11
    RWKV = 12
    PHOENIX = 13
    ROBIN = 14
    FALCON_CHAT = 15
    CHATGLM3 = 16
    INTERNVL_ZH = 17
    MPT = 18
36
+
37
+
38
@dataclasses.dataclass
class Conversation:
    """Hold a prompt template plus the running message history of one chat.

    ``get_prompt`` renders the system message and every ``(role, message)``
    pair into a single string according to ``sep_style``. A message that is
    falsy (``None`` or ``''``) is rendered as a bare role header so that the
    model continues from there.
    """

    # The name of this template
    name: str
    # The template of the system prompt; must contain ``{system_message}``
    system_template: str = '{system_message}'
    # The system message
    system_message: str = ''
    # The names of the two roles (first speaker, second speaker)
    roles: Tuple[str, str] = ('USER', 'ASSISTANT')
    # All messages. Each item is [role, message].
    # A fresh list per instance: the previous ``()`` default was immutable, so
    # ``append_message`` failed on any instance built without ``messages=``.
    messages: List[List[str]] = dataclasses.field(default_factory=list)
    # The number of few shot examples
    offset: int = 0
    # The separator style and configurations
    sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
    sep: str = '\n'
    # Second separator, used by the two-separator styles for odd-indexed turns
    sep2: Union[str, None] = None
    # Stop criteria (the default one is EOS token)
    stop_str: Union[str, List[str], None] = None
    # Stops generation if meeting any token in this list
    stop_token_ids: Union[List[int], None] = None

    def get_prompt(self) -> str:
        """Render the conversation into the prompt string for generation.

        Returns:
            The full prompt text for the configured ``sep_style``.

        Raises:
            ValueError: If ``sep_style`` is not one of the handled styles.
        """
        system_prompt = self.system_template.format(system_message=self.system_message)
        if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
            ret = system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ': ' + message + self.sep
                else:
                    ret += role + ':'
            return ret
        elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
            seps = [self.sep, self.sep2]
            ret = system_prompt + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ': ' + message + seps[i % 2]
                else:
                    ret += role + ':'
            return ret
        elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
            ret = system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ': ' + message + self.sep
                else:
                    ret += role + ': '  # must be end with a space
            return ret
        elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
            ret = '' if system_prompt == '' else system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + '\n' + message + self.sep
                else:
                    ret += role + '\n'
            return ret
        elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
            ret = system_prompt
            for role, message in self.messages:
                if message:
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
            seps = [self.sep, self.sep2]
            ret = system_prompt
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + message + seps[i % 2]
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.RWKV:
            ret = system_prompt
            for i, (role, message) in enumerate(self.messages):
                if message:
                    # Normalise line endings and collapse blank lines, since a
                    # blank line is what terminates a turn in this layout.
                    ret += (
                        role
                        + ': '
                        + message.replace('\r\n', '\n').replace('\n\n', '\n')
                    )
                    ret += '\n\n'
                else:
                    ret += role + ':'
            return ret
        elif self.sep_style == SeparatorStyle.LLAMA2:
            seps = [self.sep, self.sep2]
            if self.system_message:
                ret = system_prompt
            else:
                ret = '[INST] '
            for i, (role, message) in enumerate(self.messages):
                # The tag comes from ``self.roles`` by position, not from the
                # role stored with the message.
                tag = self.roles[i % 2]
                if message:
                    if i == 0:
                        ret += message + ' '
                    else:
                        ret += tag + ' ' + message + seps[i % 2]
                else:
                    ret += tag
            return ret
        elif self.sep_style == SeparatorStyle.CHATGLM:
            # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
            # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
            # Round numbering starts at 1 for the template named 'chatglm2', else at 0.
            round_add_n = 1 if self.name == 'chatglm2' else 0
            if system_prompt:
                ret = system_prompt + self.sep
            else:
                ret = ''

            for i, (role, message) in enumerate(self.messages):
                if i % 2 == 0:
                    ret += f'[Round {i//2 + round_add_n}]{self.sep}'

                if message:
                    ret += f'{role}:{message}{self.sep}'
                else:
                    ret += f'{role}:'
            return ret
        elif self.sep_style == SeparatorStyle.CHATML:
            ret = '' if system_prompt == '' else system_prompt + self.sep + '\n'
            for role, message in self.messages:
                if message:
                    ret += role + '\n' + message + self.sep + '\n'
                else:
                    ret += role + '\n'
            return ret
        elif self.sep_style == SeparatorStyle.CHATGLM3:
            ret = ''
            if self.system_message:
                ret += system_prompt
            for role, message in self.messages:
                if message:
                    ret += role + '\n' + ' ' + message
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.CHATINTERN:
            # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
            seps = [self.sep, self.sep2]
            ret = system_prompt
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ':' + message + seps[i % 2] + '\n'
                else:
                    ret += role + ':'
            return ret
        elif self.sep_style == SeparatorStyle.DOLLY:
            seps = [self.sep, self.sep2]
            ret = system_prompt
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ':\n' + message + seps[i % 2]
                    if i % 2 == 1:
                        ret += '\n\n'
                else:
                    ret += role + ':\n'
            return ret
        elif self.sep_style == SeparatorStyle.PHOENIX:
            ret = system_prompt
            for role, message in self.messages:
                if message:
                    ret += role + ': ' + '<s>' + message + '</s>'
                else:
                    ret += role + ': ' + '<s>'
            return ret
        elif self.sep_style == SeparatorStyle.ROBIN:
            ret = system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ':\n' + message + self.sep
                else:
                    ret += role + ':\n'
            return ret
        elif self.sep_style == SeparatorStyle.FALCON_CHAT:
            ret = ''
            if self.system_message:
                ret += system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ': ' + message + self.sep
                else:
                    ret += role + ':'

            return ret
        elif self.sep_style == SeparatorStyle.INTERNVL_ZH:
            seps = [self.sep, self.sep2]
            # This style uses the raw system message, bypassing system_template.
            ret = self.system_message + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ': ' + message + seps[i % 2]
                else:
                    ret += role + ':'
            return ret
        elif self.sep_style == SeparatorStyle.MPT:
            ret = system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    # A 3-tuple message carries two extra items; only the
                    # first element (the text) goes into the prompt.
                    if isinstance(message, tuple):
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        else:
            raise ValueError(f'Invalid style: {self.sep_style}')

    def set_system_message(self, system_message: str):
        """Set the system message."""
        self.system_message = system_message

    def append_message(self, role: str, message: str):
        """Append a new ``[role, message]`` entry to the history."""
        self.messages.append([role, message])

    def update_last_message(self, message: str):
        """Update the last output.

        The last message is typically set to be None when constructing the prompt,
        so we need to update it in-place after getting the response from a model.
        """
        self.messages[-1][1] = message

    def to_gradio_chatbot(self):
        """Convert the conversation to gradio chatbot format.

        Returns:
            A list of ``[first_speaker_msg, second_speaker_msg]`` pairs built
            from the messages after ``offset``; the second slot is ``None``
            until the reply exists.
        """
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def to_openai_api_messages(self):
        """Convert the conversation to OpenAI chat completion format.

        Returns:
            A list of ``{'role', 'content'}`` dicts starting with the system
            message. Even-indexed messages after ``offset`` become ``user``
            entries; odd-indexed ones become ``assistant`` entries unless
            their content is ``None``.
        """
        ret = [{'role': 'system', 'content': self.system_message}]

        for i, (_, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append({'role': 'user', 'content': msg})
            else:
                if msg is not None:
                    ret.append({'role': 'assistant', 'content': msg})
        return ret

    def copy(self):
        """Return a new ``Conversation`` whose message list is independent of this one."""
        return Conversation(
            name=self.name,
            system_template=self.system_template,
            system_message=self.system_message,
            roles=self.roles,
            # Rebuild each pair so edits to the copy never touch the original.
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            stop_str=self.stop_str,
            stop_token_ids=self.stop_token_ids,
        )

    def dict(self):
        """Return the template name, system message, roles, messages and offset as a dict."""
        return {
            'template_name': self.name,
            'system_message': self.system_message,
            'roles': self.roles,
            'messages': self.messages,
            'offset': self.offset,
        }
314
+
315
+
316
# Module-wide lookup table: template name -> registered Conversation prototype.
conv_templates: Dict[str, Conversation] = {}


def register_conv_template(template: Conversation, override: bool = False):
    """Store ``template`` in the global registry under ``template.name``.

    Unless ``override`` is true, registering a name that is already present
    trips an assertion.
    """
    if not override:
        assert (
            template.name not in conv_templates
        ), f'{template.name} has been registered.'

    conv_templates[template.name] = template
328
+
329
+
330
def get_conv_template(name: str) -> Conversation:
    """Look up ``name`` in the registry and return a copy of that template.

    A ``KeyError`` propagates when no template is registered under ``name``.
    """
    prototype = conv_templates[name]
    return prototype.copy()
333
+
334
+
335
# System message shared by every template registered below (Vietnamese,
# introducing the assistant as "Vintern"). Kept in one place so the four
# templates cannot drift apart.
# NOTE: the upstream InternVL templates carried Chinese system prompts at this
# position; they are replaced here by the Vietnamese prompt.
_VINTERN_SYSTEM_MESSAGE = 'Bạn là một mô hình trí tuệ nhân tạo đa phương thức Tiếng Việt có tên gọi là Vintern, được phát triển bởi người Việt. Bạn là một trợ lý trí tuệ nhân tạo hữu ích và không gây hại.'


# Both Hermes-2 and internlm2-chat are chatml-format conversation templates. The difference
# is that during training, the preprocessing function for the Hermes-2 template doesn't add
# <s> at the beginning of the tokenized sequence, while the internlm2-chat template does.
# Their rendered prompts are therefore the same at inference time; only Hermes-2
# additionally declares a stop string.
register_conv_template(
    Conversation(
        name='Hermes-2',
        system_template='<|im_start|>system\n{system_message}',
        system_message=_VINTERN_SYSTEM_MESSAGE,
        roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
        sep_style=SeparatorStyle.MPT,
        sep='<|im_end|>',
        stop_str='<|endoftext|>',
    )
)


register_conv_template(
    Conversation(
        name='internlm2-chat',
        system_template='<|im_start|>system\n{system_message}',
        system_message=_VINTERN_SYSTEM_MESSAGE,
        roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
        sep_style=SeparatorStyle.MPT,
        sep='<|im_end|>',
    )
)


register_conv_template(
    Conversation(
        name='phi3-chat',
        system_template='<|system|>\n{system_message}',
        system_message=_VINTERN_SYSTEM_MESSAGE,
        roles=('<|user|>\n', '<|assistant|>\n'),
        sep_style=SeparatorStyle.MPT,
        sep='<|end|>',
    )
)


# Same layout as internlm2-chat except that the separator ends with a newline.
register_conv_template(
    Conversation(
        name='internvl2_5',
        system_template='<|im_start|>system\n{system_message}',
        system_message=_VINTERN_SYSTEM_MESSAGE,
        roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
        sep_style=SeparatorStyle.MPT,
        sep='<|im_end|>\n',
    )
)
modeling_intern_vit.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ from typing import Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint
12
+ from einops import rearrange
13
+ from timm.models.layers import DropPath
14
+ from torch import nn
15
+ from transformers.activations import ACT2FN
16
+ from transformers.modeling_outputs import (BaseModelOutput,
17
+ BaseModelOutputWithPooling)
18
+ from transformers.modeling_utils import PreTrainedModel
19
+ from transformers.utils import logging
20
+
21
+ from .configuration_intern_vit import InternVisionConfig
22
+
23
# Optional dependency: FlashAttention 2. Any failure while importing it (not
# installed, or the import itself erroring) disables the fast path instead of
# aborting the module import.
try:
    from flash_attn.bert_padding import pad_input, unpad_input
    from flash_attn.flash_attn_interface import \
        flash_attn_varlen_qkvpacked_func
    has_flash_attn = True
except Exception:
    # ``except Exception`` rather than a bare ``except`` so that
    # KeyboardInterrupt / SystemExit are not swallowed here.
    print('FlashAttention2 is not installed.')
    has_flash_attn = False
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
class FlashAttention(nn.Module):
    """Scaled dot-product softmax attention backed by the flash-attn varlen kernel.

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
        device, dtype: Accepted for signature compatibility; not used here.
    """

    def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
        super().__init__()
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
                max_s=None, need_weights=False):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
                if unpadded: (nnz, 3, h, d)
            key_padding_mask: a bool tensor of shape (B, S)
            causal: forwarded to the flash-attn kernel.
            cu_seqlens: cumulative sequence lengths for already-unpadded input;
                when given, ``max_s`` must be given too.
            max_s: maximum sequence length, required together with ``cu_seqlens``.
            need_weights: must be False; attention weights are not available.

        Returns
        -------
            ``(output, None)``; the second element is always ``None``.
        """
        assert not need_weights
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda

        # Dropout is only applied in training mode.
        dropout_p = self.dropout_p if self.training else 0.0

        if cu_seqlens is None:
            batch_size = qkv.shape[0]
            seqlen = qkv.shape[1]
            if key_padding_mask is None:
                # No padding: every sequence has the same length, so the
                # cumulative offsets are simply multiples of ``seqlen``.
                qkv = rearrange(qkv, 'b s ... -> (b s) ...')
                max_s = seqlen
                cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                          device=qkv.device)
                output = flash_attn_varlen_qkvpacked_func(
                    qkv, cu_seqlens, max_s, dropout_p,
                    softmax_scale=self.softmax_scale, causal=causal
                )
                output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
            else:
                nheads = qkv.shape[-2]
                x = rearrange(qkv, 'b s three h d -> b s (three h d)')
                # Only the first four results are needed. Slicing keeps this
                # working if ``unpad_input`` returns extra trailing values
                # (NOTE(review): newer flash-attn releases appear to return a
                # fifth element — confirm against the installed version).
                x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)[:4]
                x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
                output_unpad = flash_attn_varlen_qkvpacked_func(
                    x_unpad, cu_seqlens, max_s, dropout_p,
                    softmax_scale=self.softmax_scale, causal=causal
                )
                # Scatter the packed result back into the padded (b, s) layout.
                output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
                                             indices, batch_size, seqlen),
                                   'b s (h d) -> b s h d', h=nheads)
        else:
            assert max_s is not None
            output = flash_attn_varlen_qkvpacked_func(
                qkv, cu_seqlens, max_s, dropout_p,
                softmax_scale=self.softmax_scale, causal=causal
            )

        return output, None
97
+
98
+
99
class InternRMSNorm(nn.Module):
    """RMS normalisation over the last dimension with a learnable gain.

    The statistics are computed in float32; the normalised values are cast
    back to the input dtype before the gain is applied.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        # Mean of squares along the feature axis (no mean subtraction).
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalised = as_fp32 * inv_rms
        return self.weight * normalised.to(original_dtype)
111
+
112
+
113
# Prefer apex's fused RMSNorm when it can be imported; otherwise keep the
# pure-PyTorch InternRMSNorm defined above.
try:
    from apex.normalization import FusedRMSNorm

    InternRMSNorm = FusedRMSNorm  # noqa

    logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
except ImportError:
    # apex is not available: nothing to override.
    pass
except Exception:
    logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm')


# Maps the norm-type string used by the vision config to a normalisation class.
NORM2FN = dict(
    rms_norm=InternRMSNorm,
    layer_norm=nn.LayerNorm,
)
131
+
132
+
133
class InternVisionEmbeddings(nn.Module):
    """Turn pixel values into a token sequence: class token + patch tokens + positions."""

    def __init__(self, config: InternVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Learnable class token, prepended to every sequence.
        self.class_embedding = nn.Parameter(
            torch.randn(1, 1, self.embed_dim),
        )

        # Non-overlapping patch projection (kernel size == stride == patch size).
        self.patch_embedding = nn.Conv2d(
            in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

    def _get_pos_embed(self, pos_embed, H, W):
        """Bicubically resample the patch position table to an ``H x W`` grid.

        ``pos_embed`` holds the patch positions only (class slot removed). The
        result has shape ``(1, H * W, channels)`` in the dtype of the input.
        """
        source_dtype = pos_embed.dtype
        side = self.image_size // self.patch_size
        # (1, side*side, C) -> (1, C, side, side); interpolate in float32.
        grid = pos_embed.float().reshape(1, side, side, -1).permute(0, 3, 1, 2)
        resized = F.interpolate(grid, size=(H, W), mode='bicubic', align_corners=False)
        # Back to a (1, H*W, C) sequence in the original dtype.
        return resized.reshape(1, -1, H * W).permute(0, 2, 1).to(source_dtype)

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        target_dtype = self.patch_embedding.weight.dtype
        # Conv output is (batch, embed_dim, grid_height, grid_width).
        patch_embeds = self.patch_embedding(pixel_values)
        batch_size, _, height, width = patch_embeds.shape
        patch_tokens = patch_embeds.flatten(2).transpose(1, 2)
        class_tokens = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
        tokens = torch.cat([class_tokens, patch_tokens], dim=1)

        # The class position is used as-is; patch positions are resampled to
        # the grid actually produced for this input.
        class_pos = self.position_embedding[:, :1, :]
        patch_pos = self._get_pos_embed(self.position_embedding[:, 1:, :], height, width)
        positions = torch.cat([class_pos, patch_pos], dim=1)

        return tokens + positions.to(target_dtype)
175
+
176
+
177
class InternAttention(nn.Module):
    """Multi-head self-attention with an optional FlashAttention path and optional q/k RMS norm."""

    def __init__(self, config: InternVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        # The flash path is used only when requested AND importable.
        self.use_flash_attn = config.use_flash_attn and has_flash_attn
        if config.use_flash_attn and not has_flash_attn:
            print('Warning: Flash Attention is not available, use_flash_attn is set to False.')
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
                f' {self.num_heads}).'
            )

        self.scale = self.head_dim ** -0.5
        # One fused projection producing q, k and v.
        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
        self.attn_drop = nn.Dropout(config.attention_dropout)
        self.proj_drop = nn.Dropout(config.dropout)

        self.qk_normalization = config.qk_normalization

        if self.qk_normalization:
            # Norms act on the full embed_dim, i.e. across all heads at once.
            self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
            self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)

        if self.use_flash_attn:
            self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
        self.proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _naive_attn(self, x):
        """Plain PyTorch attention on a ``(B, N, C)`` input."""
        B, N, C = x.shape
        packed = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        # -> (3, B, heads, N, head_dim)
        packed = packed.permute(2, 0, 3, 1, 4)
        q, k, v = packed.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        if self.qk_normalization:
            B_, H_, N_, D_ = q.shape
            # Merge the heads, normalise over embed_dim, then split the heads again.
            q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
            k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)

        scores = ((q * self.scale) @ k.transpose(-2, -1))
        weights = scores.softmax(dim=-1)
        weights = self.attn_drop(weights)

        merged = (weights @ v).transpose(1, 2).reshape(B, N, C)
        merged = self.proj(merged)
        return self.proj_drop(merged)

    def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
        """FlashAttention path; same projections as ``_naive_attn``."""
        packed = self.qkv(x)
        packed = rearrange(packed, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)

        if self.qk_normalization:
            q, k, v = packed.unbind(2)
            q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
            k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
            packed = torch.stack([q, k, v], dim=2)

        context, _ = self.inner_attn(
            packed, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
        )
        projected = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
        return self.proj_drop(projected)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.use_flash_attn:
            return self._flash_attn(hidden_states)
        return self._naive_attn(hidden_states)
249
+
250
+
251
class InternMLP(nn.Module):
    """Two-layer feed-forward block: linear -> activation -> linear."""

    def __init__(self, config: InternVisionConfig):
        super().__init__()
        self.config = config
        # Activation is looked up by name from the config.
        self.act = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded = self.act(self.fc1(hidden_states))
        return self.fc2(expanded)
264
+
265
+
266
class InternVisionEncoderLayer(nn.Module):
    """One pre-norm transformer block with per-channel residual scaling and drop-path."""

    def __init__(self, config: InternVisionConfig, drop_path_rate: float):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.norm_type = config.norm_type

        self.attn = InternAttention(config)
        self.mlp = InternMLP(config)
        self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
        self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)

        # Learnable per-channel scales on each residual branch, initialised to
        # ``config.initializer_factor``.
        self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
        self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
        # Drop-path is skipped entirely (identity) when the rate is zero.
        self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(
            self,
            hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`

        Returns:
            `torch.Tensor`: the layer output (a single tensor, not a tuple).
        """
        # Attention branch: norm -> attention -> scale -> drop-path -> residual add.
        # The norm output is cast back to the input dtype before attention.
        hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1)

        # MLP branch, same structure.
        hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2)

        return hidden_states
296
+
297
+
298
class InternVisionEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`InternVisionEncoderLayer`].

    Args:
        config (`InternVisionConfig`):
            The vision configuration used to build every layer.
    """

    def __init__(self, config: InternVisionConfig):
        super().__init__()
        self.config = config
        # stochastic depth decay rule: the drop-path rate grows linearly from 0
        # at the first layer to config.drop_path_rate at the last.
        drop_rates = [rate.item() for rate in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
        blocks = [InternVisionEncoderLayer(config, drop_rates[layer_idx])
                  for layer_idx in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(blocks)
        self.gradient_checkpointing = True

    def forward(
            self,
            inputs_embeds,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded representation of the inputs. Should be float, not int tokens.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. Falls back to
                `config.output_hidden_states` when `None`.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Falls back to
                `config.use_return_dict` when `None`.
        """
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Holds the input of every layer plus the final output when requested.
        encoder_states = () if output_hidden_states else None
        hidden_states = inputs_embeds

        for encoder_layer in self.layers:
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # Checkpointing is only used in training mode.
            use_checkpoint = self.gradient_checkpointing and self.training
            if use_checkpoint:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    encoder_layer,
                    hidden_states)
            else:
                hidden_states = encoder_layer(
                    hidden_states,
                )

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=hidden_states, hidden_states=encoder_states
            )
        # Tuple form omits entries that are None.
        return tuple(v for v in [hidden_states, encoder_states] if v is not None)
362
+
363
+
364
class InternVisionModel(PreTrainedModel):
    """Vision backbone: patch/position embeddings followed by the transformer encoder."""

    main_input_name = 'pixel_values'
    _supports_flash_attn_2 = True
    config_class = InternVisionConfig
    _no_split_modules = ['InternVisionEncoderLayer']

    def __init__(self, config: InternVisionConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = InternVisionEmbeddings(config)
        self.encoder = InternVisionEncoder(config)

    def resize_pos_embeddings(self, old_size, new_size, patch_size):
        """Replace the position table with one bicubically resampled for ``new_size``.

        Args:
            old_size: image side length the current table corresponds to.
            new_size: image side length to resample for.
            patch_size: patch side length; both sizes are floor-divided by it
                to obtain the grid sides.
        """
        pos_emb = self.embeddings.position_embedding
        _, num_positions, embed_dim = pos_emb.shape
        # The class-token position is kept untouched; only the patch grid is resized.
        cls_emb = pos_emb[:, :1, :]
        pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
        pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False)
        pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
        pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
        self.embeddings.position_embedding = nn.Parameter(pos_emb)
        self.embeddings.image_size = new_size
        logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size))

    def get_input_embeddings(self):
        """Return the embeddings module."""
        return self.embeddings

    def forward(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            pixel_embeds: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Encode images (or precomputed embeddings) and pool the class token.

        Args:
            pixel_values: 4-D image tensor; embedded by ``self.embeddings``.
            output_hidden_states: defaults to ``config.output_hidden_states``.
            return_dict: defaults to ``config.use_return_dict``.
            pixel_embeds: already-embedded input; takes precedence over
                ``pixel_values`` when given.

        Returns:
            A ``BaseModelOutputWithPooling``, or the tuple
            ``(last_hidden_state, pooled_output, *extra)`` when ``return_dict``
            is false.

        Raises:
            ValueError: if neither input is given, or ``pixel_values`` is not 4-D.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None and pixel_embeds is None:
            raise ValueError('You have to specify pixel_values or pixel_embeds')

        if pixel_embeds is not None:
            hidden_states = pixel_embeds
        else:
            if len(pixel_values.shape) == 4:
                hidden_states = self.embeddings(pixel_values)
            else:
                raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Index instead of attribute access: with return_dict=False the encoder
        # returns a plain tuple, which has no ``.last_hidden_state``.
        last_hidden_state = encoder_outputs[0]
        # Pooled output is the class-token position (index 0 of the sequence).
        pooled_output = last_hidden_state[:, 0, :]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
modeling_internvl_chat.py ADDED
@@ -0,0 +1,543 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import warnings
8
+ from typing import List, Optional, Tuple, Union
9
+
10
+ import torch.utils.checkpoint
11
+ import transformers
12
+ from torch import nn
13
+ from torch.nn import CrossEntropyLoss
14
+ from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
15
+ Qwen2ForCausalLM)
16
+ from transformers.modeling_outputs import CausalLMOutputWithPast
17
+ from transformers.modeling_utils import PreTrainedModel
18
+ from transformers.utils import ModelOutput, logging
19
+
20
+ from .configuration_internvl_chat import InternVLChatConfig
21
+ from .conversation import get_conv_template
22
+ from .modeling_intern_vit import InternVisionModel, has_flash_attn
23
+ from PIL import Image, ImageDraw, ImageFont
24
+ import numpy as np
25
+ import cv2
26
+ import imageio
27
+ from scipy.ndimage import gaussian_filter
28
+ from PIL import Image, ImageDraw, ImageFont
29
+ import tqdm
30
+ import random
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
def version_cmp(v1, v2, op='eq'):
    """Compare two version strings.

    ``op`` names a function in the standard ``operator`` module (for example
    ``'eq'``, ``'ge'``, ``'lt'``); it is applied to the parsed versions and
    its result is returned.
    """
    import operator

    from packaging import version

    compare = getattr(operator, op)
    left, right = version.parse(v1), version.parse(v2)
    return compare(left, right)
40
+
41
def draw_text_to_image(text, font, image_width=500, min_height=500, bg_color=(255, 255, 255)):
    """Render ``text`` onto a new RGB image, word-wrapping to ``image_width``.

    Paragraphs are split on newlines and words on single spaces. A line is
    flushed when adding the next word would exceed the usable width
    (``image_width`` minus a 10 px margin on each side). A single word wider
    than that is kept on its own line rather than emitting an empty line
    before it.

    Args:
        text: Text to draw; ``'\\n'`` starts a new paragraph.
        font: Font object providing ``getbbox``, passed to ``ImageDraw.text``.
        image_width: Width of the output image in pixels.
        min_height: Lower bound for the output image height in pixels.
        bg_color: Background colour of the image.

    Returns:
        A ``PIL.Image.Image`` in RGB mode, tall enough for every drawn line.
    """
    margin = 10  # padding used on the left and top, and reserved on the right
    line_spacing = 1.2  # vertical advance as a multiple of the line's bbox bottom
    max_line_width = image_width - 2 * margin

    # All wrapped lines, in drawing order.
    lines = []
    for paragraph in text.split('\n'):
        current_line = ""
        for word in paragraph.split(' '):
            candidate = current_line + word + " "
            bbox = font.getbbox(candidate)
            too_wide = bbox[2] - bbox[0] > max_line_width
            if too_wide and current_line:
                lines.append(current_line)
                current_line = word + " "
            else:
                # Either it fits, or the line is empty and the word cannot be
                # split further: keep it instead of flushing an empty line.
                current_line = candidate
        lines.append(current_line)  # last line of the paragraph

    # Measure exactly the lines that will be drawn so the canvas matches.
    line_heights = [font.getbbox(line)[3] for line in lines]
    text_height = sum(line_heights)
    required_height = int(margin + line_spacing * text_height) + margin
    total_height = max(min_height, int(text_height * 1.25), required_height)

    image = Image.new('RGB', (image_width, total_height), color=bg_color)
    draw = ImageDraw.Draw(image)
    # NOTE(review): each channel is randomly 0 or 1, i.e. visually black —
    # confirm whether a wider colour range was intended.
    text_color = tuple(random.randint(0, 1) for _ in range(3))
    y_text = margin
    for line, height in zip(lines, line_heights):
        draw.text((margin, y_text), line, font=font, fill=text_color)
        y_text += height * line_spacing
    return image
73
+
74
def load_image_v2(image_file, input_size=448, max_num=12):
    """Load an image file and turn it into a stacked tensor of tiles.

    The image is opened with PIL, converted to RGB, split by
    ``dynamic_preprocess`` (with ``use_thumbnail=True``) and each resulting
    image is passed through the transform from ``build_transform``.

    Returns:
        ``(pixel_values, target_aspect_ratio)`` where ``pixel_values`` is the
        ``torch.stack`` of the transformed tiles and ``target_aspect_ratio``
        is whatever ``dynamic_preprocess`` returned as its second value.
    """
    # NOTE(review): build_transform and dynamic_preprocess are not defined or
    # imported in the visible part of this module — confirm they are in scope
    # at call time, otherwise this raises NameError.
    rgb_image = Image.open(image_file).convert('RGB')
    to_tensor = build_transform(input_size=input_size)
    tiles, target_aspect_ratio = dynamic_preprocess(
        rgb_image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = torch.stack([to_tensor(tile) for tile in tiles])
    return pixel_values, target_aspect_ratio
81
+
82
def adjust_overlay(overlay, text_img):
    """Return a copy of ``overlay`` whose height equals that of ``text_img``.

    * Portrait overlays (height > width) are resized to the target height
      with the width scaled by the same factor.
    * Other overlays are copied; if shorter than the target they are padded
      at the bottom with white 3-channel ``uint8`` rows.
    * Anything still not at the target height is resized to it while keeping
      its current width.
    """
    target_h, _ = text_img.shape[:2]
    src_h, src_w = overlay.shape[:2]

    if src_h > src_w:
        # Portrait: match the target height, keep the aspect ratio.
        scaled_w = int(src_w * (target_h / src_h))
        result = cv2.resize(overlay, (scaled_w, target_h))
    else:
        # Landscape or square: start from an unmodified copy.
        result = overlay.copy()

    missing_rows = target_h - result.shape[0]
    if missing_rows > 0:
        # Pad below with white so the overlay reaches the text image height.
        white = np.full((missing_rows, result.shape[1], 3), 255, dtype=np.uint8)
        result = np.concatenate((result, white), axis=0)

    if result.shape[0] != target_h:
        # Still taller than the target: force the height, keep the width.
        result = cv2.resize(result, (result.shape[1], target_h))

    return result
106
+
107
+ class InternVLChatModel(PreTrainedModel):
108
+ config_class = InternVLChatConfig
109
+ main_input_name = 'pixel_values'
110
+ base_model_prefix = 'language_model'
111
+ _supports_flash_attn_2 = True
112
+ _no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'Qwen2DecoderLayer']
113
+
114
def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None, use_flash_attn=True):
    """Assemble the vision tower, the language model and the projector MLP.

    Args:
        config: Chat-model configuration holding ``vision_config`` and
            ``llm_config`` plus the projector/template settings read below.
        vision_model: Optional prebuilt vision model; otherwise an
            ``InternVisionModel`` is built from ``config.vision_config``.
        language_model: Optional prebuilt language model; otherwise a
            ``LlamaForCausalLM`` or ``Qwen2ForCausalLM`` is built according
            to ``config.llm_config.architectures[0]``.
        use_flash_attn: Request flash attention; forced off when
            ``has_flash_attn`` is false.

    Raises:
        NotImplementedError: If no language model is supplied and the
            configured architecture is neither of the two supported ones.
    """
    super().__init__(config)

    assert version_cmp(transformers.__version__, '4.37.0', 'ge')

    vision_cfg = config.vision_config
    llm_cfg = config.llm_config

    image_size = config.force_image_size or vision_cfg.image_size
    patch_size = vision_cfg.patch_size
    self.patch_size = patch_size
    self.select_layer = config.select_layer
    self.template = config.template
    # Visual tokens contributed by one image tile after downsampling.
    patches_per_side = image_size // patch_size
    self.num_image_token = int(patches_per_side ** 2 * (config.downsample_ratio ** 2))
    self.downsample_ratio = config.downsample_ratio
    self.ps_version = config.ps_version

    if not has_flash_attn:
        use_flash_attn = False
    vision_cfg.use_flash_attn = bool(use_flash_attn)
    llm_cfg._attn_implementation = 'flash_attention_2' if use_flash_attn else 'eager'

    logger.info(f'num_image_token: {self.num_image_token}')
    logger.info(f'ps_version: {self.ps_version}')

    if vision_model is None:
        vision_model = InternVisionModel(config.vision_config)
    self.vision_model = vision_model

    if language_model is None:
        architecture = config.llm_config.architectures[0]
        supported = {
            'LlamaForCausalLM': LlamaForCausalLM,
            'Qwen2ForCausalLM': Qwen2ForCausalLM,
        }
        llm_class = supported.get(architecture)
        if llm_class is None:
            raise NotImplementedError(f'{architecture} is not implemented.')
        language_model = llm_class(config.llm_config)
    self.language_model = language_model

    vit_hidden_size = config.vision_config.hidden_size
    llm_hidden_size = config.llm_config.hidden_size

    # Projector input width: ViT channels times the squared inverse ratio.
    projector_in = vit_hidden_size * int(1 / self.downsample_ratio) ** 2
    self.mlp1 = nn.Sequential(
        nn.LayerNorm(projector_in),
        nn.Linear(projector_in, llm_hidden_size),
        nn.GELU(),
        nn.Linear(llm_hidden_size, llm_hidden_size),
    )

    # Set by chat()/batch_chat() before generation.
    self.img_context_token_id = None
    self.conv_template = get_conv_template(self.template)
    self.system_message = self.conv_template.system_message
159
+
160
def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        image_flags: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
) -> 'Union[Tuple, CausalLMOutputWithPast]':
    """Run the language model on text embeddings with image features spliced in.

    Token embeddings at positions where ``input_ids`` equals
    ``self.img_context_token_id`` are replaced by the features returned by
    ``self.extract_feature(pixel_values)``.

    Args:
        pixel_values: Image tensor passed to ``extract_feature``.
        input_ids: Token ids of shape ``(batch, seq_len)``.
        attention_mask, position_ids, past_key_values, use_cache,
        output_attentions, output_hidden_states: Forwarded unchanged to the
            language model.
        image_flags: Optional per-image flags; only images whose flag equals
            1 contribute features. ``None`` keeps every image.
        labels: Optional targets; when given, a next-token cross-entropy
            loss is computed.
        return_dict: Return a ``CausalLMOutputWithPast`` instead of a tuple;
            defaults to ``self.config.use_return_dict``.

    Returns:
        A ``CausalLMOutputWithPast``, or a tuple ``(loss?, logits, ...)``
        when ``return_dict`` is false.
    """
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Clone so the in-place splice below never writes into the embedding output.
    input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()

    vit_embeds = self.extract_feature(pixel_values)
    if image_flags is not None:
        image_flags = image_flags.squeeze(-1)
        vit_embeds = vit_embeds[image_flags == 1]
    vit_batch_size = pixel_values.shape[0]

    B, N, C = input_embeds.shape
    input_embeds = input_embeds.reshape(B * N, C)

    if (torch.distributed.is_available() and torch.distributed.is_initialized()
            and torch.distributed.get_rank() == 0):
        print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')

    input_ids = input_ids.reshape(B * N)
    selected = (input_ids == self.img_context_token_id)
    try:
        # "* 0.0 +" keeps the original embedding in the expression while
        # replacing its value with the visual feature.
        input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
    except Exception as e:
        # Best effort on a count mismatch: warn and use the leading features.
        vit_embeds = vit_embeds.reshape(-1, C)
        print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
              f'vit_embeds.shape={vit_embeds.shape}')
        n_token = selected.sum()
        input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]

    input_embeds = input_embeds.reshape(B, N, C)

    outputs = self.language_model(
        inputs_embeds=input_embeds,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    # No labels are passed to the language model, so its first output is the
    # logits whether it returned a tuple or a model-output object.
    logits = outputs[0]

    loss = None
    if labels is not None:
        # Shift so that tokens < n predict n.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens.
        loss_fct = CrossEntropyLoss()
        shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
        shift_labels = shift_labels.view(-1)
        # Enable model parallelism.
        shift_labels = shift_labels.to(shift_logits.device)
        loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
238
+
239
def pixel_shuffle(self, x, scale_factor=0.5):
    """Trade spatial resolution for channels on an ``(N, W, H, C)`` tensor.

    With ``scale_factor`` s the two spatial axes are scaled by s and the
    channel axis by ``1 / s**2``. For ``ps_version == 'v1'`` the final swap
    of the spatial axes is skipped and a warning is emitted instead.
    """
    batch, width, height, channels = x.size()
    scaled_h = int(height * scale_factor)
    scaled_w = int(width * scale_factor)

    # (N, W, H, C) -> (N, W, H*s, C/s): fold neighbours along H into channels.
    x = x.view(batch, width, scaled_h, int(channels / scale_factor))
    # -> (N, H*s, W, C/s)
    x = x.transpose(1, 2).contiguous()
    # -> (N, H*s, W*s, C/s^2): fold neighbours along W into channels.
    x = x.view(batch, scaled_h, scaled_w, int(channels / (scale_factor * scale_factor)))

    if self.ps_version != 'v1':
        # Swap the spatial axes back to the input order.
        return x.transpose(1, 2).contiguous()

    warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
                  'which results in a transposed image.')
    return x
254
+
255
def extract_feature(self, pixel_values):
    """Encode images into language-model-sized visual tokens.

    Takes the vision model's last hidden state (``select_layer == -1``) or
    the hidden state at index ``select_layer``, drops the first token,
    arranges the rest as a square grid, applies ``pixel_shuffle`` with
    ``self.downsample_ratio``, flattens the grid again and projects it
    through ``self.mlp1``.
    """
    use_intermediate_layer = self.select_layer != -1
    vision_out = self.vision_model(
        pixel_values=pixel_values,
        output_hidden_states=use_intermediate_layer,
        return_dict=True)
    if use_intermediate_layer:
        tokens = vision_out.hidden_states[self.select_layer]
    else:
        tokens = vision_out.last_hidden_state

    # Drop the token at sequence position 0 before forming the grid.
    tokens = tokens[:, 1:, :]

    side = int(tokens.shape[1] ** 0.5)
    grid = tokens.reshape(tokens.shape[0], side, side, -1)
    grid = self.pixel_shuffle(grid, scale_factor=self.downsample_ratio)
    flat = grid.reshape(grid.shape[0], -1, grid.shape[-1])
    return self.mlp1(flat)
274
+
275
def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
               history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
               IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
    """Answer several single-turn questions in one generation call.

    Args:
        tokenizer: Tokenizer used to build inputs and decode outputs. Its
            ``padding_side`` is set to ``'left'`` as a side effect.
        pixel_values: Image tensor for all questions, or ``None``.
        questions: One question string per entry of ``num_patches_list``.
        generation_config: Mapping of keyword arguments for ``generate``.
            It is copied; the caller's mapping is not modified.
        num_patches_list: Number of image tiles belonging to each question.
        history, return_history: Multi-turn chat is unsupported here; a
            non-``None`` history or ``return_history=True`` raises.
        IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN: Marker strings
            used to expand the ``<image>`` placeholder.
        verbose: Print the image batch size.
        image_counts: Deprecated alias for ``num_patches_list``.

    Returns:
        A list with one decoded response string per question.

    Raises:
        NotImplementedError: If multi-turn chat is requested.
    """
    if history is not None or return_history:
        print('Now multi-turn chat is not supported in batch_chat.')
        raise NotImplementedError

    if image_counts is not None:
        num_patches_list = image_counts
        print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')

    self.img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)

    if verbose and pixel_values is not None:
        image_bs = pixel_values.shape[0]
        print(f'dynamic ViT batch size: {image_bs}')

    queries = []
    for idx, num_patches in enumerate(num_patches_list):
        question = questions[idx]
        if pixel_values is not None and '<image>' not in question:
            question = '<image>\n' + question
        template = get_conv_template(self.template)
        template.system_message = self.system_message
        template.append_message(template.roles[0], question)
        template.append_message(template.roles[1], None)
        query = template.get_prompt()

        # Expand the first <image> placeholder into this sample's image tokens.
        image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
        query = query.replace('<image>', image_tokens, 1)
        queries.append(query)

    # Read the separator from a fresh template so it does not depend on the
    # loop variable, which is undefined for an empty batch.
    separator = get_conv_template(self.template).sep.strip()

    tokenizer.padding_side = 'left'
    model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
    input_ids = model_inputs['input_ids'].to(self.device)
    attention_mask = model_inputs['attention_mask'].to(self.device)

    # Work on a copy so the caller's generation settings are left untouched.
    generation_config = dict(generation_config)
    generation_config['eos_token_id'] = tokenizer.convert_tokens_to_ids(separator)
    generation_output = self.generate(
        pixel_values=pixel_values,
        input_ids=input_ids,
        attention_mask=attention_mask,
        **generation_config
    )
    responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
    return [response.split(separator)[0].strip() for response in responses]
323
+
324
def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
         num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
         verbose=False, attention_visualize=False, last_visualize_layers=7, raw_image_path="",
         target_aspect_ratio=(1, 1)):
    """Run one chat turn, optionally returning raw outputs for attention plots.

    Args:
        tokenizer: Tokenizer used to build the prompt and decode the answer.
        pixel_values: Image tensor for the conversation, or ``None``.
        question: The user's message. ``'<image>\\n'`` is prepended when
            images are given, there is no history and no ``<image>``
            placeholder is present.
        generation_config: Mapping of keyword arguments for ``generate``.
            It is copied; the caller's mapping is not modified.
        history: Previous ``(question, answer)`` pairs. When given, the new
            pair is appended to this list in place.
        return_history: Also return the updated history.
        num_patches_list: Tiles per image; defaults to one image covering
            all of ``pixel_values``.
        IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN: Marker strings
            used to expand each ``<image>`` placeholder.
        verbose: Print the image batch size and the prompt/response.
        attention_visualize: Return the raw ``generate`` output (requested
            with hidden states) together with the prompt instead of text.
        last_visualize_layers, raw_image_path, target_aspect_ratio: Accepted
            for compatibility; not used by this method.

    Returns:
        * ``(generation_output, query)`` when ``attention_visualize`` is true;
        * ``(response, history)`` when ``return_history`` is true;
        * ``response`` otherwise.
    """
    if history is None and pixel_values is not None and '<image>' not in question:
        question = '<image>\n' + question

    if num_patches_list is None:
        num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
    assert pixel_values is None or len(pixel_values) == sum(num_patches_list)

    self.img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)

    template = get_conv_template(self.template)
    template.system_message = self.system_message
    separator = template.sep.strip()
    eos_token_id = tokenizer.convert_tokens_to_ids(separator)

    history = [] if history is None else history
    for old_question, old_answer in history:
        template.append_message(template.roles[0], old_question)
        template.append_message(template.roles[1], old_answer)
    template.append_message(template.roles[0], question)
    template.append_message(template.roles[1], None)
    query = template.get_prompt()

    if verbose and pixel_values is not None:
        image_bs = pixel_values.shape[0]
        print(f'dynamic ViT batch size: {image_bs}')

    # Expand one <image> placeholder per image, in order.
    for num_patches in num_patches_list:
        image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
        query = query.replace('<image>', image_tokens, 1)

    model_inputs = tokenizer(query, return_tensors='pt')
    input_ids = model_inputs['input_ids'].to(self.device)
    attention_mask = model_inputs['attention_mask'].to(self.device)

    # Work on a copy so the caller's generation settings are left untouched.
    generation_config = dict(generation_config)
    generation_config['eos_token_id'] = eos_token_id

    if attention_visualize:
        # Hand back the raw generation output and the expanded prompt. The
        # frame-rendering code that used to follow this return was already
        # unreachable and commented out; heat maps can be built from the
        # output with visualize_attention().
        generation_output = self.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            attention_visualize=attention_visualize,
            output_hidden_states=True,
            **generation_config
        )
        return generation_output, query

    generation_output = self.generate(
        pixel_values=pixel_values,
        input_ids=input_ids,
        attention_mask=attention_mask,
        attention_visualize=attention_visualize,
        **generation_config
    )

    response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
    response = response.split(separator)[0].strip()
    history.append((question, response))
    if return_history:
        return response, history

    query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
    query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
    if verbose:
        print(query_to_print, response)
    return response
467
+
468
def visualize_attention(self, attention_tensor, layer=0, head=None, start_img_token_index=0,
                        end_img_token_index=0, target_aspect_ratio=(0, 0), tile_grid=16):
    """Build 2-D heat maps of attention over the image tokens.

    Args:
        attention_tensor: Sequence of per-layer attention arrays indexed by
            ``layer``; each is indexed as ``(batch_or_beam, head, query, key)``.
            # assumes NumPy arrays — TODO confirm against the caller
        layer: Index of the layer to use.
        head: Head index to use, or ``None`` to average over all heads.
        start_img_token_index: First flattened score index belonging to the
            image tokens.
        end_img_token_index: One past the last image-token index.
        target_aspect_ratio: ``(tiles_wide, tiles_high)``; a 0 entry is
            treated as 1.
        tile_grid: Side length of the square token grid of a single tile.

    Returns:
        A list with one ``(tiles_high * tile_grid, tiles_wide * tile_grid)``
        array per batch/beam entry.
    """
    layer_attention = attention_tensor[layer]
    if head is None:
        scores = layer_attention.mean(axis=1)  # average over heads
    else:
        scores = layer_attention[:, head, :, :]  # one specific head

    # Flatten everything except the batch/beam axis. A bare squeeze() would
    # also drop that axis when it has size 1 and break the loop below.
    scores = scores.reshape(scores.shape[0], -1)
    scores = np.power(scores, 0.9)

    tiles_w = target_aspect_ratio[0] or 1
    tiles_h = target_aspect_ratio[1] or 1

    heat_maps = []
    for beam_scores in scores:
        image_scores = beam_scores[start_img_token_index:end_img_token_index]
        # Tokens are laid out tile by tile; interleave the tiles' rows so the
        # result is one contiguous image-shaped grid.
        image_scores = image_scores.reshape(tiles_h, tiles_w, tile_grid, tile_grid)
        image_scores = np.transpose(image_scores, (0, 2, 1, 3)).reshape(
            tiles_h * tile_grid, tiles_w * tile_grid)
        heat_maps.append(image_scores)
    return heat_maps
489
+
490
+
491
+
492
+
493
@torch.no_grad()
def generate(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        visual_features: Optional[torch.FloatTensor] = None,
        generation_config: 'Optional[GenerationConfig]' = None,
        output_hidden_states: Optional[bool] = None,
        attention_visualize: Optional[bool] = False,
        **generate_kwargs,
) -> torch.LongTensor:
    """Generate with the language model, splicing image features into the prompt.

    Args:
        pixel_values: Images to encode with ``extract_feature``. When
            ``None`` the prompt is treated as text only.
        input_ids: Prompt token ids of shape ``(batch, seq_len)``.
        attention_mask: Attention mask forwarded to the language model.
        visual_features: Precomputed features used instead of encoding
            ``pixel_values`` (only consulted when ``pixel_values`` is given).
        generation_config: Forwarded to ``language_model.generate``.
        output_hidden_states: Forwarded to ``language_model.generate``.
        attention_visualize: Force ``output_attentions=True`` and
            ``return_dict_in_generate=True``.
        **generate_kwargs: Extra arguments for ``language_model.generate``.
            ``output_attentions`` / ``return_dict_in_generate`` may be given
            here; they are honoured unless ``attention_visualize`` is set.

    Returns:
        Whatever ``language_model.generate`` returns: generated ids by
        default, or its structured output when ``return_dict_in_generate``
        is enabled.
    """
    assert self.img_context_token_id is not None

    input_embeds = self.language_model.get_input_embeddings()(input_ids)
    if pixel_values is not None:
        if visual_features is not None:
            vit_embeds = visual_features
        else:
            vit_embeds = self.extract_feature(pixel_values)

        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)

        # Overwrite the placeholder-token embeddings with the image features.
        selected = (input_ids.reshape(B * N) == self.img_context_token_id)
        assert selected.sum() != 0
        input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)

        input_embeds = input_embeds.reshape(B, N, C)

    # Pull these out of the kwargs so passing them explicitly below cannot
    # collide with a caller-supplied value (duplicate-keyword TypeError).
    output_attentions = bool(generate_kwargs.pop('output_attentions', False))
    return_dict_in_generate = bool(generate_kwargs.pop('return_dict_in_generate', False))
    if attention_visualize:
        output_attentions = True
        return_dict_in_generate = True

    outputs = self.language_model.generate(
        inputs_embeds=input_embeds,
        attention_mask=attention_mask,
        generation_config=generation_config,
        output_hidden_states=output_hidden_states,
        use_cache=True,
        output_attentions=output_attentions,
        return_dict_in_generate=return_dict_in_generate,
        **generate_kwargs,
    )

    return outputs
543
+