littlebird13 commited on
Commit
a4fd82c
1 Parent(s): b417b11

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +286 -0
app.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import copy
7
+ import re
8
+ import os
9
+ from argparse import ArgumentParser
10
+ from threading import Thread
11
+ import spaces
12
+
13
+ import gradio as gr
14
+ from qwen_vl_utils import process_vision_info
15
+ from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, TextIteratorStreamer
16
+
17
# Default Hugging Face model ID loaded when --checkpoint-path is not given.
DEFAULT_CKPT_PATH = 'Qwen/Qwen2-VL-7B-Instruct'
18
+
19
+
20
def _get_args():
    """Build the command-line parser for the demo and return the parsed args."""
    parser = ArgumentParser()
    parser.add_argument(
        '-c', '--checkpoint-path', type=str, default=DEFAULT_CKPT_PATH,
        help='Checkpoint name or path, default to %(default)r')
    parser.add_argument('--cpu-only', action='store_true', help='Run demo with CPU only')
    parser.add_argument(
        '--share', action='store_true', default=False,
        help='Create a publicly shareable link for the interface.')
    parser.add_argument(
        '--inbrowser', action='store_true', default=False,
        help='Automatically launch the interface in a new tab on the default browser.')
    parser.add_argument('--server-port', type=int, default=7860, help='Demo server port.')
    parser.add_argument('--server-name', type=str, default='127.0.0.1', help='Demo server name.')
    return parser.parse_args()
43
+
44
+
45
def _load_model_processor(args):
    """Load the Qwen2-VL model and its processor from ``args.checkpoint_path``.

    Returns a ``(model, processor)`` tuple.
    """
    # Pin everything to CPU when requested; otherwise let HF place the weights.
    device_map = 'cpu' if args.cpu_only else 'auto'

    # flash_attention_2 is enabled for better speed and memory use, which
    # matters most for multi-image and video inputs.
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        args.checkpoint_path,
        torch_dtype='auto',
        attn_implementation='flash_attention_2',
        device_map=device_map,
    )
    processor = AutoProcessor.from_pretrained(args.checkpoint_path)
    return model, processor
62
+
63
+
64
+ def _parse_text(text):
65
+ lines = text.split('\n')
66
+ lines = [line for line in lines if line != '']
67
+ count = 0
68
+ for i, line in enumerate(lines):
69
+ if '```' in line:
70
+ count += 1
71
+ items = line.split('`')
72
+ if count % 2 == 1:
73
+ lines[i] = f'<pre><code class="language-{items[-1]}">'
74
+ else:
75
+ lines[i] = '<br></code></pre>'
76
+ else:
77
+ if i > 0:
78
+ if count % 2 == 1:
79
+ line = line.replace('`', r'\`')
80
+ line = line.replace('<', '&lt;')
81
+ line = line.replace('>', '&gt;')
82
+ line = line.replace(' ', '&nbsp;')
83
+ line = line.replace('*', '&ast;')
84
+ line = line.replace('_', '&lowbar;')
85
+ line = line.replace('-', '&#45;')
86
+ line = line.replace('.', '&#46;')
87
+ line = line.replace('!', '&#33;')
88
+ line = line.replace('(', '&#40;')
89
+ line = line.replace(')', '&#41;')
90
+ line = line.replace('$', '&#36;')
91
+ lines[i] = '<br>' + line
92
+ text = ''.join(lines)
93
+ return text
94
+
95
+
96
+ def _remove_image_special(text):
97
+ text = text.replace('<ref>', '').replace('</ref>', '')
98
+ return re.sub(r'<box>.*?(</box>|$)', '', text)
99
+
100
+
101
# Extensions treated as video when routing uploaded files to the model.
# Hoisted to module level so the tuple is not rebuilt on every call; a tuple
# also lets str.endswith test all options in one C-level call.
_VIDEO_EXTENSIONS = ('.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.mpeg')


def is_video_file(filename):
    """Return True if *filename* ends with a known video extension.

    The check is case-insensitive (the name is lower-cased first).
    """
    return filename.lower().endswith(_VIDEO_EXTENSIONS)
104
+
105
+
106
def transform_messages(original_messages):
    """Normalize chat-history messages into the typed content format.

    Each content entry such as ``{'image': v}`` becomes
    ``{'type': 'image', 'image': v}`` (likewise for 'text' and 'video');
    entries carrying none of the known keys are dropped. Roles are preserved
    and the input is not mutated.
    """
    result = []
    for message in original_messages:
        content = []
        for item in message['content']:
            # First matching key wins, in the same priority order as before:
            # image, then text, then video.
            for key in ('image', 'text', 'video'):
                if key in item:
                    content.append({'type': key, key: item[key]})
                    break
        result.append({'role': message['role'], 'content': content})
    return result
125
+
126
+
127
def _launch_demo(args, model, processor):
    """Build the Gradio chat UI around the loaded model and start serving it."""

    @spaces.GPU
    def call_local_model(model, processor, messages):
        # Generator: streams the growing response text for the given chat
        # history. Runs generation on a background thread and yields the
        # accumulated text each time the streamer produces a new fragment.

        messages = transform_messages(messages)

        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)
        # NOTE(review): inputs are moved to "cuda" unconditionally, even when
        # --cpu-only was passed — presumably fine under spaces.GPU; confirm.
        inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors='pt').to("cuda")

        tokenizer = processor.tokenizer
        # timeout=20.0: iterating the streamer raises if no new token arrives
        # within 20s, so a stalled generation cannot hang the UI forever.
        streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)

        gen_kwargs = {'max_new_tokens': 512, 'streamer': streamer, **inputs}

        # Generate on a worker thread so this generator can consume the
        # streamer concurrently.
        thread = Thread(target=model.generate, kwargs=gen_kwargs)
        thread.start()

        generated_text = ''
        for new_text in streamer:
            generated_text += new_text
            yield generated_text

    def create_predict_fn():
        # Factory so `predict` can close over model/processor via nonlocal.

        def predict(_chatbot, task_history):
            # Stream a model reply for the newest user turn, updating both the
            # rendered chatbot pairs and the raw task_history in place.
            nonlocal model, processor
            chat_query = _chatbot[-1][0]
            query = task_history[-1][0]
            if len(chat_query) == 0:
                # Empty query: drop the placeholder turn and do nothing.
                _chatbot.pop()
                task_history.pop()
                return _chatbot
            print('User: ' + _parse_text(query))
            history_cp = copy.deepcopy(task_history)
            full_response = ''
            messages = []
            content = []
            # Rebuild the typed message list. File turns (tuple/list entries)
            # accumulate into `content`; a text turn flushes them as one user
            # message followed by the recorded assistant reply.
            for q, a in history_cp:
                if isinstance(q, (tuple, list)):
                    if is_video_file(q[0]):
                        content.append({'video': f'file://{q[0]}'})
                    else:
                        content.append({'image': f'file://{q[0]}'})
                else:
                    content.append({'text': q})
                    messages.append({'role': 'user', 'content': content})
                    messages.append({'role': 'assistant', 'content': [{'text': a}]})
                    content = []
            # Drop the trailing assistant stub for the turn being answered now.
            messages.pop()

            for response in call_local_model(model, processor, messages):
                _chatbot[-1] = (_parse_text(chat_query), _remove_image_special(_parse_text(response)))

                yield _chatbot
                full_response = _parse_text(response)

            # Persist the final answer into the raw history for future turns.
            task_history[-1] = (query, full_response)
            print('Qwen-VL-Chat: ' + _parse_text(full_response))
            yield _chatbot

        return predict

    def create_regenerate_fn():
        # Factory so `regenerate` can close over model/processor via nonlocal.

        def regenerate(_chatbot, task_history):
            # Re-run generation for the last answered turn: clear its stored
            # answer, restore the pending user bubble, then delegate to predict.
            nonlocal model, processor
            if not task_history:
                return _chatbot
            item = task_history[-1]
            if item[1] is None:
                # Last turn has no answer yet — nothing to regenerate.
                return _chatbot
            task_history[-1] = (item[0], None)
            chatbot_item = _chatbot.pop(-1)
            if chatbot_item[0] is None:
                _chatbot[-1] = (_chatbot[-1][0], None)
            else:
                _chatbot.append((chatbot_item[0], None))
            _chatbot_gen = predict(_chatbot, task_history)
            for _chatbot in _chatbot_gen:
                yield _chatbot

        return regenerate

    predict = create_predict_fn()
    regenerate = create_regenerate_fn()

    def add_text(history, task_history, text):
        # Append the typed query to both histories (rendered + raw) and clear
        # the textbox via the third return value.
        task_text = text
        history = history if history is not None else []
        task_history = task_history if task_history is not None else []
        history = history + [(_parse_text(text), None)]
        task_history = task_history + [(task_text, None)]
        return history, task_history, ''

    def add_file(history, task_history, file):
        # Record an uploaded file as a (path,) tuple turn in both histories.
        history = history if history is not None else []
        task_history = task_history if task_history is not None else []
        history = history + [((file.name,), None)]
        task_history = task_history + [((file.name,), None)]
        return history, task_history

    def reset_user_input():
        # Clear the input textbox.
        return gr.update(value='')

    def reset_state(task_history):
        # Wipe the raw history in place and empty the chatbot display.
        task_history.clear()
        return []

    with gr.Blocks() as demo:
        gr.Markdown("""\
<p align="center"><img src="https://qianwen-res.oss-accelerate-overseas.aliyuncs.com/Qwen2-VL/qwen2VL_logo.png" style="height: 80px"/><p>"""
                    )
        gr.Markdown("""<center><font size=8>Qwen2-VL</center>""")
        gr.Markdown("""\
<center><font size=3>This WebUI is based on Qwen2-VL, developed by Alibaba Cloud.</center>""")
        gr.Markdown("""<center><font size=3>本WebUI基于Qwen2-VL。</center>""")

        chatbot = gr.Chatbot(label='Qwen2-VL', elem_classes='control-height', height=500)
        query = gr.Textbox(lines=2, label='Input')
        task_history = gr.State([])

        with gr.Row():
            addfile_btn = gr.UploadButton('📁 Upload (上传文件)', file_types=['image', 'video'])
            submit_btn = gr.Button('🚀 Submit (发送)')
            regen_btn = gr.Button('🤔️ Regenerate (重试)')
            empty_bin = gr.Button('🧹 Clear History (清除历史)')

        # Wire events: submit adds the text then streams a prediction; a second
        # click handler clears the textbox in parallel.
        submit_btn.click(add_text, [chatbot, task_history, query],
                         [chatbot, task_history]).then(predict, [chatbot, task_history], [chatbot], show_progress=True)
        submit_btn.click(reset_user_input, [], [query])
        empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
        regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)
        addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True)

        gr.Markdown("""\
<font size=2>Note: This demo is governed by the original license of Qwen2-VL. \
We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, \
including hate speech, violence, pornography, deception, etc. \
(注:本演示受Qwen2-VL的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,\
包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)""")

    # queue() enables streaming/generator outputs; launch options come from CLI.
    demo.queue().launch(
        share=args.share,
        inbrowser=args.inbrowser,
        server_port=args.server_port,
        server_name=args.server_name,
    )
276
+
277
+
278
def main():
    """Entry point: parse CLI args, load the model, then serve the demo."""
    cli_args = _get_args()
    model, processor = _load_model_processor(cli_args)
    _launch_demo(cli_args, model, processor)


if __name__ == '__main__':
    main()
286
+