VictorChew
committed on
Upload folder using huggingface_hub
Browse files
- README.md +8 -5
- config.json +2 -2
- configuration_intern_vit.py +1 -0
- configuration_internvl_chat.py +3 -3
- conversation.py +16 -19
- generation_config.json +1 -1
- model.safetensors +1 -1
- modeling_intern_vit.py +1 -0
- modeling_internvl_chat.py +13 -14
README.md
CHANGED
@@ -13,7 +13,7 @@ license: apache-2.0
 
 [[ Github Repo ]](https://github.com/UniModal4Reasoning/StructEqTable-Deploy) [[ Related Paper ]](https://arxiv.org/abs/2406.11633) [[ Website ]](https://unimodal4reasoning.github.io/DocGenome_page/)
 
-[[ Dataset🤗 ]](https://huggingface.co/datasets/U4R/DocGenome/tree/main) [[ Models🤗 ]](https://huggingface.co/U4R/StructTable-InternVL2-1B/tree/main)
+[[ Dataset🤗 ]](https://huggingface.co/datasets/U4R/DocGenome/tree/main) [[ Models🤗 ]](https://huggingface.co/U4R/StructTable-InternVL2-1B/tree/main) [[ Demo💬 ]](https://www.modelscope.cn/studios/HongbinZhou/StructEqTable-Demo/)
 
 
 </div>
@@ -24,7 +24,9 @@ Welcome to the official repository of StructEqTable-Deploy, a solution that conv
 Tables are an effective way to represent structured data in scientific publications, financial statements, invoices, web pages, and many other scenarios. Extracting tabular data from a table image and performing downstream reasoning tasks on the extracted data is challenging, mainly because tables often present complicated column and row headers with spanning cells. To address these challenges, we present TableX, a large-scale multi-modal table benchmark extracted from the [DocGenome benchmark](https://unimodal4reasoning.github.io/DocGenome_page/) for table pre-training, comprising more than 2 million high-quality image-LaTeX pairs covering 156 disciplinary classes. Benefiting from such large-scale data, we train an end-to-end model, StructEqTable, which precisely produces the LaTeX description of a visual table image and performs multiple table-related reasoning tasks, including structural extraction and question answering, broadening its application scope and potential.
 
 ## Changelog
-- [2024/
+- [2024/12/12] 🔥 We have released the latest model **[StructTable-InternVL2-1B v0.2](https://huggingface.co/U4R/StructTable-InternVL2-1B/tree/main)** with enhanced recognition stability for HTML and Markdown formats!
+
+- [2024/10/19] We have released our latest model StructTable-InternVL2-1B!
 
 Thanks to InternVL2's powerful foundational capabilities, and through fine-tuning on synthetic tabular data and the DocGenome dataset, StructTable can convert a table image into common table formats including LaTeX, HTML, and Markdown. Moreover, inference speed has been significantly improved compared to the v0.2 version.
 - [2024/8/22] We have released our StructTable-base-v0.2, fine-tuned on the DocGenome dataset. This version features improved inference speed and robustness, achieved through data augmentation and a reduced image token count.
@@ -62,9 +64,10 @@ pip install struct-eqtable==0.3.0
 
 | Base Model | Model Size | Training Data | Data Augmentation | LMDeploy | TensorRT | HuggingFace |
 |---------------------|------------|------------------|-------------------|----------|----------|-------------------|
-| InternVL2-1B | ~1B | DocGenome and Synthetic Data | ✔ | ✔ | | [StructTable v0.
-|
-| Pix2Struct-base | ~300M | DocGenome | | | ✔ | [StructTable v0.
+| InternVL2-1B | ~1B | DocGenome and Synthetic Data | ✔ | ✔ | | [StructTable-InternVL2-1B v0.2](https://huggingface.co/U4R/StructTable-InternVL2-1B/tree/main) |
+| InternVL2-1B | ~1B | DocGenome and Synthetic Data | ✔ | ✔ | | [StructTable-InternVL2-1B v0.1](https://huggingface.co/U4R/StructTable-InternVL2-1B/tree/v0.1) |
+| Pix2Struct-base | ~300M | DocGenome | ✔ | | ✔ | [StructTable-base v0.2](https://huggingface.co/U4R/StructTable-base/tree/v0.2) |
+| Pix2Struct-base | ~300M | DocGenome | | | ✔ | [StructTable-base v0.1](https://huggingface.co/U4R/StructTable-base/tree/v0.1) |
 
 
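For a quick check of the released checkpoint, the model card's Hugging Face links point at the usual remote-code loading path. A minimal sketch (assuming `transformers` with `trust_remote_code`; the exact image preprocessing and prompt format come from the StructEqTable-Deploy repo, not from this diff):

```python
# Minimal loading sketch, not the official inference script.
# Assumes a recent transformers (config.json pins 4.44.2) and a CUDA device.
import torch
from transformers import AutoModel, AutoTokenizer

path = "U4R/StructTable-InternVL2-1B"
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,   # matches "torch_dtype": "bfloat16" in config.json
    trust_remote_code=True,       # pulls modeling_internvl_chat.py from this repo
).eval().cuda()
```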
config.json
CHANGED
@@ -87,7 +87,7 @@
 "top_p": 1.0,
 "torch_dtype": "bfloat16",
 "torchscript": false,
-"transformers_version": "4.44.
+"transformers_version": "4.44.2",
 "typical_p": 1.0,
 "use_bfloat16": true,
 "use_cache": false,
@@ -185,7 +185,7 @@
 "top_p": 1.0,
 "torch_dtype": "bfloat16",
 "torchscript": false,
-"transformers_version": "4.44.
+"transformers_version": "4.44.2",
 "typical_p": 1.0,
 "use_bfloat16": true,
 "use_flash_attn": true
configuration_intern_vit.py
CHANGED
@@ -3,6 +3,7 @@
 # Copyright (c) 2024 OpenGVLab
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
+
 import os
 from typing import Union
 
configuration_internvl_chat.py
CHANGED
@@ -46,12 +46,12 @@ class InternVLChatConfig(PretrainedConfig):
             logger.info('llm_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`).')
 
         self.vision_config = InternVisionConfig(**vision_config)
-        if llm_config
+        if llm_config.get('architectures')[0] == 'LlamaForCausalLM':
             self.llm_config = LlamaConfig(**llm_config)
-        elif llm_config
+        elif llm_config.get('architectures')[0] == 'Qwen2ForCausalLM':
             self.llm_config = Qwen2Config(**llm_config)
         else:
-            raise ValueError('Unsupported architecture: {}'.format(llm_config
+            raise ValueError('Unsupported architecture: {}'.format(llm_config.get('architectures')[0]))
         self.use_backbone_lora = use_backbone_lora
         self.use_llm_lora = use_llm_lora
         self.select_layer = select_layer
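The change above makes the LLM sub-config dispatch read the `architectures` field of the nested `llm_config` dict via `.get()`. A small standalone illustration of the resulting behavior (the helper name and example dict are mine; only the branch logic mirrors the diff):

```python
# Standalone sketch of the architecture dispatch in configuration_internvl_chat.py.
from transformers import LlamaConfig, Qwen2Config


def build_llm_config(llm_config: dict):
    arch = llm_config.get('architectures')[0]
    if arch == 'LlamaForCausalLM':
        return LlamaConfig(**llm_config)
    elif arch == 'Qwen2ForCausalLM':
        return Qwen2Config(**llm_config)
    raise ValueError('Unsupported architecture: {}'.format(arch))


# The 1B chat model's llm_config declares Qwen2ForCausalLM, so the Qwen2 branch is taken.
cfg = build_llm_config({'architectures': ['Qwen2ForCausalLM']})
print(type(cfg).__name__)  # Qwen2Config
```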
conversation.py
CHANGED
@@ -3,11 +3,13 @@ Conversation prompt templates.
 
 We kindly request that you import fastchat instead of copying this file if you wish to use it.
 If you have changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates.
+
+Modified from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
 """
 
 import dataclasses
 from enum import IntEnum, auto
-from typing import
+from typing import Dict, List, Tuple, Union
 
 
 class SeparatorStyle(IntEnum):
@@ -340,17 +342,10 @@ register_conv_template(
         system_template='<|im_start|>system\n{system_message}',
         # note: The new system prompt was not used here to avoid changes in benchmark performance.
         # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
-
-        system_message='You are a Table Image to LaTeX/Markdown/HMTL Code converter.',
+        system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
         roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
         sep_style=SeparatorStyle.MPT,
         sep='<|im_end|>',
-        stop_token_ids=[
-            2,
-            6,
-            7,
-            8,
-        ],
         stop_str='<|endoftext|>',
     )
 )
@@ -366,11 +361,6 @@ register_conv_template(
         roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
         sep_style=SeparatorStyle.MPT,
         sep='<|im_end|>',
-        stop_token_ids=[
-            2,
-            92543,
-            92542
-        ]
     )
 )
 
@@ -385,10 +375,17 @@ register_conv_template(
         roles=('<|user|>\n', '<|assistant|>\n'),
         sep_style=SeparatorStyle.MPT,
         sep='<|end|>',
-
-
-
-
-
+    )
+)
+
+
+register_conv_template(
+    Conversation(
+        name='internvl2_5',
+        system_template='<|im_start|>system\n{system_message}',
+        system_message='你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
+        roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
+        sep_style=SeparatorStyle.MPT,
+        sep='<|im_end|>\n',
     )
 )
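For reference, the registered templates are consumed through this file's own `get_conv_template` API (the same call used by `modeling_internvl_chat.py`). A usage sketch for the newly added `internvl2_5` template, with a made-up question string and assuming the module is importable from the repo directory:

```python
# Sketch: building a prompt from the internvl2_5 template added in this commit.
from conversation import get_conv_template

template = get_conv_template('internvl2_5')
template.append_message(template.roles[0], '<image>\nConvert this table to LaTeX.')  # hypothetical query
template.append_message(template.roles[1], None)   # leave the assistant slot open for generation
prompt = template.get_prompt()                     # system block + user turn, separated by '<|im_end|>\n'
print(prompt)
```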
generation_config.json
CHANGED
@@ -4,5 +4,5 @@
 151644,
 151645
 ],
-"transformers_version": "4.44.
+"transformers_version": "4.44.2"
 }
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:59901e249fd22bef86f66f22005e091379c92a706f55340ac5fc481d930757fb
 size 1876395376
modeling_intern_vit.py
CHANGED
@@ -3,6 +3,7 @@
 # Copyright (c) 2024 OpenGVLab
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
+
 from typing import Optional, Tuple, Union
 
 import torch
modeling_internvl_chat.py
CHANGED
@@ -3,8 +3,9 @@
 # Copyright (c) 2024 OpenGVLab
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
+
 import warnings
-from typing import
+from typing import List, Optional, Tuple, Union
 
 import torch.utils.checkpoint
 import transformers
@@ -34,6 +35,7 @@ def version_cmp(v1, v2, op='eq'):
 class InternVLChatModel(PreTrainedModel):
     config_class = InternVLChatConfig
     main_input_name = 'pixel_values'
+    base_model_prefix = 'language_model'
     _supports_flash_attn_2 = True
     _no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'Qwen2DecoderLayer']
 
@@ -99,10 +101,11 @@ class InternVLChatModel(PreTrainedModel):
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
+        # image_flags = image_flags.squeeze(-1)
         image_flags = image_flags.squeeze(0)
         pixel_values = pixel_values.squeeze(0)
 
-        input_embeds = self.language_model.get_input_embeddings()(input_ids)
+        input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()
 
         vit_embeds = self.extract_feature(pixel_values)
         vit_embeds = vit_embeds[image_flags == 1]
@@ -116,7 +119,6 @@ class InternVLChatModel(PreTrainedModel):
 
         input_ids = input_ids.reshape(B * N)
         selected = (input_ids == self.img_context_token_id)
-
         try:
             input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
         except Exception as e:
@@ -236,9 +238,9 @@ class InternVLChatModel(PreTrainedModel):
 
         tokenizer.padding_side = 'left'
         model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
-        input_ids = model_inputs['input_ids'].
-        attention_mask = model_inputs['attention_mask'].
-        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
+        input_ids = model_inputs['input_ids'].to(self.device)
+        attention_mask = model_inputs['attention_mask'].to(self.device)
+        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())
         generation_config['eos_token_id'] = eos_token_id
         generation_output = self.generate(
             pixel_values=pixel_values,
@@ -247,7 +249,7 @@ class InternVLChatModel(PreTrainedModel):
             **generation_config
         )
         responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
-        responses = [response.split(template.sep)[0].strip() for response in responses]
+        responses = [response.split(template.sep.strip())[0].strip() for response in responses]
         return responses
 
     def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
@@ -266,7 +268,7 @@ class InternVLChatModel(PreTrainedModel):
 
         template = get_conv_template(self.template)
         template.system_message = self.system_message
-        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
+        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())
 
         history = [] if history is None else history
         for (old_question, old_answer) in history:
@@ -285,10 +287,9 @@ class InternVLChatModel(PreTrainedModel):
         query = query.replace('<image>', image_tokens, 1)
 
         model_inputs = tokenizer(query, return_tensors='pt')
-        input_ids = model_inputs['input_ids'].
-        attention_mask = model_inputs['attention_mask'].
+        input_ids = model_inputs['input_ids'].to(self.device)
+        attention_mask = model_inputs['attention_mask'].to(self.device)
         generation_config['eos_token_id'] = eos_token_id
-
         generation_output = self.generate(
             pixel_values=pixel_values,
             input_ids=input_ids,
@@ -296,7 +297,7 @@ class InternVLChatModel(PreTrainedModel):
             **generation_config
         )
         response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
-        response = response.split(template.sep)[0].strip()
+        response = response.split(template.sep.strip())[0].strip()
         history.append((question, response))
         if return_history:
             return response, history
@@ -316,7 +317,6 @@ class InternVLChatModel(PreTrainedModel):
             visual_features: Optional[torch.FloatTensor] = None,
             generation_config: Optional[GenerationConfig] = None,
             output_hidden_states: Optional[bool] = None,
-            return_dict: Optional[bool] = None,
             img_context_token_id: Optional[bool] = None,
             **generate_kwargs,
     ) -> torch.LongTensor:
@@ -347,7 +347,6 @@ class InternVLChatModel(PreTrainedModel):
             attention_mask=attention_mask,
             generation_config=generation_config,
             output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
             use_cache=True,
             **generate_kwargs,
         )
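A note on the recurring `template.sep.strip()` change: the newly added `internvl2_5` template uses `sep='<|im_end|>\n'`, and only the stripped `'<|im_end|>'` is a registered special token, so looking up the unstripped separator would not yield a usable EOS id for generation or for splitting the decoded response. A small sketch of the difference (tokenizer loading assumed as in the model card, not part of this diff):

```python
# Why the eos_token_id lookup strips the separator before converting it.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("U4R/StructTable-InternVL2-1B", trust_remote_code=True)

sep = '<|im_end|>\n'                                  # sep of the internvl2_5 template
print(tokenizer.convert_tokens_to_ids(sep))           # not a registered token -> no usable id
print(tokenizer.convert_tokens_to_ids(sep.strip()))   # id of the real '<|im_end|>' special token
```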