princepride committed
Commit b32599e · verified · 1 Parent(s): a01fbc4

Upload 2 files

Files changed (2)
  1. app.py +415 -0
  2. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,415 @@
+ import streamlit as st
+ import tempfile
+ import os
+ import time
+ import re
+ import numpy as np
+ import torch
+ from PIL import Image
+ from decord import VideoReader, cpu
+ import torchvision.transforms as T
+ from torchvision.transforms.functional import InterpolationMode
+ from transformers import AutoModel, AutoTokenizer, TextIteratorStreamer
+ from threading import Thread
+
+ # Set page configuration
+ st.set_page_config(page_title="Omni DeepSeek Video Analysis", layout="wide")
+
+ # Constants
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
+ IMAGENET_STD = (0.229, 0.224, 0.225)
+
+ # Add CSS for text wrapping and a vertical scrollbar for the expander
+ st.markdown("""
+ <style>
+ .output-text {
+     white-space: pre-wrap !important;
+     word-wrap: break-word !important;
+ }
+ .streamlit-expanderContent {
+     white-space: pre-wrap !important;
+     word-wrap: break-word !important;
+     max-height: 100px; /* adjust the height as needed */
+     overflow-y: auto; /* add a vertical scrollbar */
+ }
+ </style>
+ """, unsafe_allow_html=True)
+
+ # Model loading utilities
+ @st.cache_resource
+ def load_model_and_tokenizer():
+     """Load and cache the model and tokenizer"""
+     path = 'AlphaTok/omni-deepseek-v0'
+
+     with st.spinner("Loading model (this may take a minute)..."):
+         model = AutoModel.from_pretrained(
+             path,
+             torch_dtype=torch.bfloat16,
+             low_cpu_mem_usage=True,
+             use_flash_attn=True,
+             trust_remote_code=True
+         ).eval()
+
+         # Move to GPU if available
+         if torch.cuda.is_available():
+             model = model.cuda()
+             st.success("Model loaded on GPU")
+         else:
+             st.warning("GPU not available, running on CPU (inference will be slow)")
+
+         tokenizer = AutoTokenizer.from_pretrained(
+             path,
+             trust_remote_code=True,
+             use_fast=False
+         )
+
+     return model, tokenizer
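+
+ # Note: st.cache_resource keeps a single model/tokenizer pair alive across Streamlit
+ # reruns, so the weights above are downloaded and loaded only once per process.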
+
+ # Video processing functions
+ def build_transform(input_size):
+     MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
+     transform = T.Compose([
+         T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+         T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+         T.ToTensor(),
+         T.Normalize(mean=MEAN, std=STD)
+     ])
+     return transform
+
+ def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
+     orig_width, orig_height = image.size
+     aspect_ratio = orig_width / orig_height
+     target_ratios = set(
+         (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
+         i * j <= max_num and i * j >= min_num)
+     target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+     # Calculate the target aspect ratio
+     def find_closest_aspect_ratio(aspect_ratio, target_ratios):
+         best_ratio_diff = float('inf')
+         best_ratio = (1, 1)
+         for ratio in target_ratios:
+             target_aspect_ratio = ratio[0] / ratio[1]
+             ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+             if ratio_diff < best_ratio_diff:
+                 best_ratio_diff = ratio_diff
+                 best_ratio = ratio
+         return best_ratio
+
+     target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios)
+     target_width = image_size * target_aspect_ratio[0]
+     target_height = image_size * target_aspect_ratio[1]
+     blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+     resized_img = image.resize((target_width, target_height))
+     processed_images = []
+     for i in range(blocks):
+         box = (
+             (i % (target_width // image_size)) * image_size,
+             (i // (target_width // image_size)) * image_size,
+             ((i % (target_width // image_size)) + 1) * image_size,
+             ((i // (target_width // image_size)) + 1) * image_size
+         )
+         split_img = resized_img.crop(box)
+         processed_images.append(split_img)
+     if use_thumbnail and len(processed_images) != 1:
+         thumbnail_img = image.resize((image_size, image_size))
+         processed_images.append(thumbnail_img)
+     return processed_images
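+
+ # Worked example (illustrative): with the defaults (max_num=12), a 1280x720 image
+ # (aspect ratio ~1.78) is matched to the closest grid ratio (2, 1), resized to
+ # 896x448, and cropped into two 448x448 tiles (plus a thumbnail when
+ # use_thumbnail=True). load_video below calls this with max_num=1, so each video
+ # frame becomes a single 448x448 tile.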
+
+ def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
+     if bound:
+         start, end = bound[0], bound[1]
+     else:
+         start, end = -100000, 100000
+     start_idx = max(first_idx, round(start * fps))
+     end_idx = min(round(end * fps), max_frame)
+     seg_size = float(end_idx - start_idx) / num_segments
+     frame_indices = np.array([
+         int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
+         for idx in range(num_segments)
+     ])
+     return frame_indices
+
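+ # get_index splits [start_idx, end_idx] into num_segments equal bins and samples the
+ # frame at the centre of each bin, i.e. num_segments evenly spaced frames across the
+ # whole clip when bound is None.
+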
+ def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
+     vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+     max_frame = len(vr) - 1
+     fps = float(vr.get_avg_fps())
+
+     pixel_values_list, num_patches_list = [], []
+     transform = build_transform(input_size=input_size)
+     frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)
+     for frame_index in frame_indices:
+         img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB')
+         img = dynamic_preprocess(img, image_size=input_size, use_thumbnail=True, max_num=max_num)
+         pixel_values = [transform(tile) for tile in img]
+         pixel_values = torch.stack(pixel_values)
+         num_patches_list.append(pixel_values.shape[0])
+         pixel_values_list.append(pixel_values)
+     pixel_values = torch.cat(pixel_values_list)
+     return pixel_values, num_patches_list
+
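+ # load_video stacks one [num_patches, 3, input_size, input_size] tensor per sampled
+ # frame, so the returned pixel_values has shape [sum(num_patches_list), 3, 448, 448].
+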
+ # Save uploaded file to a temporary location
+ def save_uploaded_file(uploaded_file):
+     with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp:
+         tmp.write(uploaded_file.getvalue())
+         return tmp.name
+
+ def process_video_and_run_inference(video_path, prompt, model, tokenizer):
+     # Load and preprocess the video
+     with st.spinner("Processing video..."):
+         pixel_values, num_patches_list = load_video(
+             video_path,
+             num_segments=16,
+             max_num=1
+         )
+         if torch.cuda.is_available():
+             pixel_values = pixel_values.to(torch.bfloat16).cuda()
+         else:
+             pixel_values = pixel_values.to(torch.bfloat16)
+
+     # Initialize the streamer used for text generation
+     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10)
+     generation_config = dict(max_new_tokens=1024, do_sample=False, streamer=streamer)
+
+     # Start the model chat thread
+     thread = Thread(
+         target=model.chat,
+         kwargs=dict(
+             tokenizer=tokenizer,
+             pixel_values=pixel_values,
+             question=prompt,
+             history=None,
+             return_history=False,
+             generation_config=generation_config,
+         )
+     )
+     thread.start()
+
+     # Accumulates the model's raw output
+     raw_output = ""
+
+     # State variables used to split the output into think and regular parts
+     think_mode = False
+     think_content = ""
+     regular_content = ""
+
+     # Process each text chunk received from the streamer
+     for new_text in streamer:
+         # Append the new raw text to raw_output
+         raw_output += new_text
+
+         pos = 0
+         while pos < len(new_text):
+             idx_think = new_text.find("<think>", pos)
+             idx_think_close = new_text.find("</think>", pos)
+             # If this chunk contains no tags, append the remainder to the current mode and stop scanning
+             if idx_think == -1 and idx_think_close == -1:
+                 if think_mode:
+                     think_content += new_text[pos:]
+                     yield {"type": "think", "content": think_content}
+                 else:
+                     regular_content += new_text[pos:]
+                     yield {"type": "regular", "content": regular_content}
+                 break
+             # If <think> appears first, or </think> is absent
+             if idx_think != -1 and (idx_think_close == -1 or idx_think < idx_think_close):
+                 # First handle the content before the tag
+                 if think_mode:
+                     think_content += new_text[pos:idx_think]
+                     yield {"type": "think", "content": think_content}
+                 else:
+                     regular_content += new_text[pos:idx_think]
+                     yield {"type": "regular", "content": regular_content}
+                 pos = idx_think + len("<think>")
+                 think_mode = True
+             else:
+                 # Handle an occurrence of </think>
+                 if think_mode:
+                     think_content += new_text[pos:idx_think_close]
+                     yield {"type": "think", "content": think_content}
+                     think_content = ""  # clear the think content buffer
+                 else:
+                     regular_content += new_text[pos:idx_think_close]
+                     yield {"type": "regular", "content": regular_content}
+                 pos = idx_think_close + len("</think>")
+                 think_mode = False
+
+     thread.join()  # make sure the generation thread has finished
+
+     # Print the complete raw model output to the terminal
+     print("Complete raw model output:")
+     print(raw_output)
+
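+ # Note: this splitter assumes "<think>" / "</think>" arrive intact within a single
+ # streamed chunk; a tag split across chunk boundaries is passed through as plain text.
+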
+ # Main app function
+ def main():
+     st.title("Video Analysis with Omni DeepSeek")
+     st.markdown("Upload a video and provide a prompt to analyze it.")
+
+     # Load model and tokenizer
+     model, tokenizer = load_model_and_tokenizer()
+
+     # Sidebar for inputs
+     with st.sidebar:
+         st.header("Upload and Settings")
+         video_file = st.file_uploader("Upload Video", type=["mp4", "avi", "mov", "mkv"])
+         # Prompt template selection: the dropdown offers the default and the omni-matrix templates
+         template_option = st.selectbox("Select Prompt Template", options=["Default", "Omni-Matrix Template"])
+         if template_option == "Default":
+             prompt = st.text_area("Enter your prompt", value="Please describe this video", height=100)
+         else:
+             prompt = st.text_area("Enter your prompt", value=f"""
+ Analyze the video and determine whether the user requires assistance based on the video activity type and behavior. Generate the output in the following structured JSON format:
+
+ 1. **help_needed**: A boolean value (true or false) indicating whether the user needs help based on the video content.
+ 2. **video_description**: A brief description of the video content.
+ 3. **video_type**: The type of activity in the video. Options include working, meeting, coding, gaming, watching, or other.
+ 4. **function_call_name**: If help_needed is true, specify the name of the function to provide assistance. Options include draft_copy (drafting a copy), assist_coding (coding assistance), and web_search (web search). If no help is needed, return an empty string.
+ 5. **function_call_parameters**: If help is needed, provide the required parameters for the function call; otherwise, return an empty array. The parameters are defined as follows:
+ - **draft_copy**: Two strings - the first is the copy subject and the second is the copy content.
+ -- copy_subject(str): The subject of the copy
+ -- copy_content(str): The content of the copy
+ - **web_search**:
+ -- web_search_content(str): A single string containing the search query.
+ - **assist_coding**:
+ -- coding_subject(str): The subject of the code
+ -- coding_content(str): The content of the code
+
+ **Input Requirements:**
+ The input is a description of the video, and the model needs to analyze it to determine user behavior and generate a JSON response in the following format:
+
+ json
+ {{
+ "help_needed": true/false,
+ "video_description": "Brief description of the video content",
+ "video_type": "working"/"meeting"/"coding"/"gaming"/"watching"/"other",
+ "function_call_name": "draft_copy/assist_coding/web_search",
+ "function_call_parameters": {{
+ "parameter1":"parameter1 content",
+ "parameter2":"parameter2 content"
+ }}
+ }}
+
+ **Examples:**
+ 1. If the video shows the user debugging code and repeatedly checking documentation:
+ json
+ {{
+ "help_needed": true,
+ "video_description": "The user is debugging code and may need assistance.",
+ "video_type": "coding",
+ "function_call_name": "assist_coding",
+ "function_call_parameters": {{
+ "coding_subject": "Help the user implement quicksort.",
+ "coding_content": "
+ def quicksort(arr):
+     if len(arr) <= 1:
+         return arr
+
+     pivot = arr[len(arr) // 2]
+     left = [x for x in arr if x < pivot]
+     middle = [x for x in arr if x == pivot]
+     right = [x for x in arr if x > pivot]
+     return quicksort(left) + middle + quicksort(right)
+ "
+ }}
+ }}
+
+ 2. If the video shows the user watching a movie and no assistance is required:
+ json
+ {{
+ "help_needed": false,
+ "video_description": "The user is watching a movie.",
+ "video_type": "watching",
+ "function_call_name": "",
+ "function_call_parameters": []
+ }}
+
+ 3. If the video shows the user writing an email and might need assistance drafting it:
+ json
+ {{
+ "help_needed": true,
+ "video_description": "The user is writing an email and may need assistance.",
+ "video_type": "working",
+ "function_call_name": "draft_copy",
+ "function_call_parameters": {{
+ "copy_subject": "Follow-up Meeting",
+ "copy_content": "Please confirm your availability for the next meeting."
+ }}
+ }}
+
+ 4. If the video shows the user searching for a specific topic online:
+ json
+ {{
+ "help_needed": true,
+ "video_description": "The user is searching for information online.",
+ "video_type": "working",
+ "function_call_name": "web_search",
+ "function_call_parameters": {{
+ "web_search_content": "latest AI research papers"
+ }}
+ }}
+ """, height=400)
+         run_button = st.button("Analyze Video", type="primary")
+
+         st.markdown("---")
+         st.markdown("### Model Information")
+         st.info("Using AlphaTok/omni-deepseek-v0 model")
+
+     # Main content area with two columns
+     col1, col2 = st.columns([1, 1])
+
+     with col1:
+         st.header("Input")
+         if video_file:
+             st.video(video_file)
+             st.text(f"Prompt: {prompt}")
+
+     with col2:
+         st.header("Output")
+         # Expand the thinking expander by default
+         thinking_container = st.expander("Thinking Process", expanded=True)
+         output_container = st.container()
+
+     if run_button and video_file and prompt:
+         # Save the uploaded video
+         video_path = save_uploaded_file(video_file)
+
+         # Create a progress bar
+         progress_bar = st.progress(0.0)
+
+         # Placeholders for streaming output
+         thinking_placeholder = thinking_container.empty()
+         output_placeholder = output_container.empty()
+
+         try:
+             progress_step = 0
+             # Hold the progress bar at no more than 90% while output is streaming
+             for result in process_video_and_run_inference(video_path, prompt, model, tokenizer):
+                 progress_step += 1
+                 progress_bar.progress(min(0.9, progress_step / 1024))
+                 if result["type"] == "think":
+                     thinking_placeholder.markdown(f"""<div class="output-text">{result['content']}</div>""", unsafe_allow_html=True)
+                 elif result["type"] == "regular":
+                     content = result["content"]
+                     if re.search(r'```\s*json\s*\{', content):
+                         json_content = re.search(r'```\s*json\s*(\{.*?\})\s*```', content, re.DOTALL)
+                         if json_content:
+                             output_placeholder.json(json_content.group(1))
+                         else:
+                             output_placeholder.markdown(f"""<div class="output-text">{content}</div>""", unsafe_allow_html=True)
+                     else:
+                         output_placeholder.markdown(f"""<div class="output-text">{content}</div>""", unsafe_allow_html=True)
+
+             # Finish the progress bar once generation is done
+             progress_bar.progress(1.0)
+             time.sleep(0.5)
+             progress_bar.empty()
+             os.unlink(video_path)
+
+         except Exception as e:
+             st.error(f"An error occurred: {str(e)}")
+             if os.path.exists(video_path):
+                 os.unlink(video_path)
+
+ if __name__ == "__main__":
+     main()
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ decord
+ transformers<4.50.0
+ einops
+ timm
+ accelerate>=0.26.0
+ sentencepiece
+ pandas
+ tqdm
+ flash-attn
+ streamlit
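
Note: these dependencies mirror the imports in app.py — decord for frame decoding, flash-attn for the `use_flash_attn=True` load path, and streamlit for the UI; einops and timm are presumably required by the model's remote code, and transformers is pinned below 4.50.0 presumably for compatibility with that code. Assuming both files are saved as-is, the app can be run locally with `pip install -r requirements.txt` followed by `streamlit run app.py`.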