SkyNait committed on
Commit e0b2fe7 · verified · 1 Parent(s): caa6ba9

file upload

Files changed (1)
  1. app.py +116 -85
app.py CHANGED
@@ -27,33 +27,51 @@ has_xpu = hasattr(torch, 'xpu') and torch.xpu.is_available()
 def update_model(model_id, device):
     if model_cache['model_id'] != model_id or model_cache['device'] != device:
         logging.info(f'Loading model {model_id} on {device}')
-        processor = AutoProcessor.from_pretrained(model_id)
-        # Load model with appropriate precision for each device
-        if device == 'cuda':
-            # Use bfloat16 for CUDA for performance
-            model = AutoModelForImageTextToText.from_pretrained(
-                model_id,
-                torch_dtype=torch.bfloat16,
-                _attn_implementation='flash_attention_2'
-            ).to('cuda')
-        elif device == 'xpu' and has_xpu:
-            # Use float32 on XPU to avoid bfloat16 layernorm issues
-            model = AutoModelForImageTextToText.from_pretrained(
-                model_id,
-                torch_dtype=torch.float32
-            ).to('xpu')
-        else:
-            # Default to float32 on CPU
-            model = AutoModelForImageTextToText.from_pretrained(model_id).to('cpu')
-        model.eval()
-        model_cache.update({'model_id': model_id, 'processor': processor, 'model': model, 'device': device})

 def extract_frames_from_video(video_path, max_frames=10):
     """Extract frames from video file for processing"""
     cap = cv2.VideoCapture(video_path)
     frames = []
     frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

     # Calculate step size to extract evenly distributed frames
     step = max(1, frame_count // max_frames)
 
@@ -77,71 +95,78 @@ def extract_frames_from_video(video_path, max_frames=10):
 def caption_frame(frame, model_id, interval_ms, sys_prompt, usr_prompt, device):
     """Caption a single frame (used for webcam streaming)"""
     debug_msgs = []
-    update_model(model_id, device)
-    processor = model_cache['processor']
-    model = model_cache['model']

-    # Control capture interval
-    time.sleep(interval_ms / 1000)

-    # Preprocess frame
-    t0 = time.time()
-    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-    pil_img = Image.fromarray(rgb)
-    temp_path = 'frame.jpg'
-    pil_img.save(temp_path, format='JPEG', quality=50)
-    debug_msgs.append(f'Preprocess: {int((time.time()-t0)*1000)} ms')

-    # Prepare multimodal chat messages
-    messages = [
-        {'role': 'system', 'content': [{'type': 'text', 'text': sys_prompt}]},
-        {'role': 'user', 'content': [
-            {'type': 'image', 'url': temp_path},
-            {'type': 'text', 'text': usr_prompt}
-        ]}
-    ]

-    # Tokenize and encode
-    t1 = time.time()
-    inputs = processor.apply_chat_template(
-        messages,
-        add_generation_prompt=True,
-        tokenize=True,
-        return_dict=True,
-        return_tensors='pt'
-    )
-    # Move inputs to correct device and dtype (matching model parameters)
-    param_dtype = next(model.parameters()).dtype
-    cast_inputs = {}
-    for k, v in inputs.items():
-        if isinstance(v, torch.Tensor):
-            if v.dtype.is_floating_point:
-                # cast floating-point tensors to model's parameter dtype
-                cast_inputs[k] = v.to(device=model.device, dtype=param_dtype)
             else:
-                # move integer/mask tensors without changing dtype
-                cast_inputs[k] = v.to(device=model.device)
-        else:
-            cast_inputs[k] = v
-    inputs = cast_inputs
-    debug_msgs.append(f'Tokenize: {int((time.time()-t1)*1000)} ms')

-    # Inference
-    t2 = time.time()
-    outputs = model.generate(**inputs, do_sample=False, max_new_tokens=128)
-    debug_msgs.append(f'Inference: {int((time.time()-t2)*1000)} ms')

-    # Decode and strip history
-    t3 = time.time()
-    raw = processor.batch_decode(outputs, skip_special_tokens=True)[0]
-    debug_msgs.append(f'Decode: {int((time.time()-t3)*1000)} ms')
-    if "Assistant:" in raw:
-        caption = raw.split("Assistant:")[-1].strip()
-    else:
-        lines = raw.splitlines()
-        caption = lines[-1].strip() if len(lines) > 1 else raw.strip()

-    return caption, '\n'.join(debug_msgs)

 @spaces.GPU
 def process_video_file(video_file, model_id, sys_prompt, usr_prompt, device, max_frames):
@@ -150,11 +175,13 @@ def process_video_file(video_file, model_id, sys_prompt, usr_prompt, device, max_frames):
         return "No video file uploaded", ""

     debug_msgs = []
-    update_model(model_id, device)
-    processor = model_cache['processor']
-    model = model_cache['model']

    try:
         # Extract frames from video
         t0 = time.time()
         frames = extract_frames_from_video(video_file, max_frames)
@@ -171,6 +198,7 @@ def process_video_file(video_file, model_id, sys_prompt, usr_prompt, device, max_frames):
             rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
             pil_img = Image.fromarray(rgb)
             temp_path = f'frame_{i}.jpg'
             pil_img.save(temp_path, format='JPEG', quality=50)

             # Prepare multimodal chat messages
@@ -216,17 +244,20 @@ def process_video_file(video_file, model_id, sys_prompt, usr_prompt, device, max_frames):
                 caption = lines[-1].strip() if len(lines) > 1 else raw.strip()

             captions.append(f"Frame {i+1}: {caption}")
-
-            # Clean up temp file
-            if os.path.exists(temp_path):
-                os.remove(temp_path)
-
             debug_msgs.append(f'Frame {i+1} processed in {int((time.time()-t1)*1000)} ms')

         return '\n\n'.join(captions), '\n'.join(debug_msgs)

     except Exception as e:
         return f"Error processing video: {str(e)}", '\n'.join(debug_msgs)

 def toggle_input_mode(input_mode):
     """Toggle between webcam and video file input"""
 
 def update_model(model_id, device):
     if model_cache['model_id'] != model_id or model_cache['device'] != device:
         logging.info(f'Loading model {model_id} on {device}')
+        try:
+            processor = AutoProcessor.from_pretrained(model_id)
+            # Load model with appropriate precision for each device
+            if device == 'cuda':
+                # Use bfloat16 for CUDA for performance
+                model = AutoModelForImageTextToText.from_pretrained(
+                    model_id,
+                    torch_dtype=torch.bfloat16,
+                    _attn_implementation='flash_attention_2'
+                ).to('cuda')
+            elif device == 'xpu' and has_xpu:
+                # Use float32 on XPU to avoid bfloat16 layernorm issues
+                model = AutoModelForImageTextToText.from_pretrained(
+                    model_id,
+                    torch_dtype=torch.float32
+                ).to('xpu')
+            else:
+                # Default to float32 on CPU
+                model = AutoModelForImageTextToText.from_pretrained(model_id).to('cpu')
+            model.eval()
+            model_cache.update({'model_id': model_id, 'processor': processor, 'model': model, 'device': device})
+        except Exception as e:
+            logging.error(f'Error loading model: {e}')
+            raise e

 def extract_frames_from_video(video_path, max_frames=10):
     """Extract frames from video file for processing"""
+    if not os.path.exists(video_path):
+        raise FileNotFoundError(f"Video file not found: {video_path}")
+
+    # Validate video file
+    if not video_path.lower().endswith(('.mp4', '.avi', '.mov', '.mkv', '.webm')):
+        raise ValueError("Unsupported video format. Please use MP4, AVI, MOV, MKV, or WEBM.")
+
     cap = cv2.VideoCapture(video_path)
+    if not cap.isOpened():
+        raise ValueError(f"Cannot open video file: {video_path}")
+
     frames = []
     frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

+    if frame_count == 0:
+        cap.release()
+        raise ValueError("Video file appears to be empty or corrupted")
+
     # Calculate step size to extract evenly distributed frames
     step = max(1, frame_count // max_frames)
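With these checks in place, callers of extract_frames_from_video are expected to catch the new exceptions instead of receiving a silently empty frame list; a minimal calling sketch (the sample path and max_frames value are illustrative, not part of this commit):

    try:
        frames = extract_frames_from_video('sample_clip.mp4', max_frames=8)
        print(f'Extracted {len(frames)} frames')
    except (FileNotFoundError, ValueError) as err:
        # Missing file, unsupported extension, unreadable or empty video
        print(f'Video rejected: {err}')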
 
 
 def caption_frame(frame, model_id, interval_ms, sys_prompt, usr_prompt, device):
     """Caption a single frame (used for webcam streaming)"""
     debug_msgs = []
+    try:
+        update_model(model_id, device)
+        processor = model_cache['processor']
+        model = model_cache['model']

+        # Control capture interval
+        time.sleep(interval_ms / 1000)

+        # Preprocess frame
+        t0 = time.time()
+        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        pil_img = Image.fromarray(rgb)
+        temp_path = 'frame.jpg'
+        pil_img.save(temp_path, format='JPEG', quality=50)
+        debug_msgs.append(f'Preprocess: {int((time.time()-t0)*1000)} ms')

+        # Prepare multimodal chat messages
+        messages = [
+            {'role': 'system', 'content': [{'type': 'text', 'text': sys_prompt}]},
+            {'role': 'user', 'content': [
+                {'type': 'image', 'url': temp_path},
+                {'type': 'text', 'text': usr_prompt}
+            ]}
+        ]

+        # Tokenize and encode
+        t1 = time.time()
+        inputs = processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors='pt'
+        )
+        # Move inputs to correct device and dtype (matching model parameters)
+        param_dtype = next(model.parameters()).dtype
+        cast_inputs = {}
+        for k, v in inputs.items():
+            if isinstance(v, torch.Tensor):
+                if v.dtype.is_floating_point:
+                    # cast floating-point tensors to model's parameter dtype
+                    cast_inputs[k] = v.to(device=model.device, dtype=param_dtype)
+                else:
+                    # move integer/mask tensors without changing dtype
+                    cast_inputs[k] = v.to(device=model.device)
             else:
+                cast_inputs[k] = v
+        inputs = cast_inputs
+        debug_msgs.append(f'Tokenize: {int((time.time()-t1)*1000)} ms')

+        # Inference
+        t2 = time.time()
+        outputs = model.generate(**inputs, do_sample=False, max_new_tokens=128)
+        debug_msgs.append(f'Inference: {int((time.time()-t2)*1000)} ms')

+        # Decode and strip history
+        t3 = time.time()
+        raw = processor.batch_decode(outputs, skip_special_tokens=True)[0]
+        debug_msgs.append(f'Decode: {int((time.time()-t3)*1000)} ms')
+        if "Assistant:" in raw:
+            caption = raw.split("Assistant:")[-1].strip()
+        else:
+            lines = raw.splitlines()
+            caption = lines[-1].strip() if len(lines) > 1 else raw.strip()

+        # Clean up temp file
+        if os.path.exists(temp_path):
+            os.remove(temp_path)
+
+        return caption, '\n'.join(debug_msgs)
+    except Exception as e:
+        return f"Error: {str(e)}", '\n'.join(debug_msgs)

 @spaces.GPU
 def process_video_file(video_file, model_id, sys_prompt, usr_prompt, device, max_frames):
 
         return "No video file uploaded", ""

     debug_msgs = []
+    temp_files = []  # Track temporary files for cleanup

     try:
+        update_model(model_id, device)
+        processor = model_cache['processor']
+        model = model_cache['model']
+
         # Extract frames from video
         t0 = time.time()
         frames = extract_frames_from_video(video_file, max_frames)
 
             rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
             pil_img = Image.fromarray(rgb)
             temp_path = f'frame_{i}.jpg'
+            temp_files.append(temp_path)  # Track for cleanup
             pil_img.save(temp_path, format='JPEG', quality=50)

             # Prepare multimodal chat messages
 
                 caption = lines[-1].strip() if len(lines) > 1 else raw.strip()

             captions.append(f"Frame {i+1}: {caption}")
             debug_msgs.append(f'Frame {i+1} processed in {int((time.time()-t1)*1000)} ms')

         return '\n\n'.join(captions), '\n'.join(debug_msgs)

     except Exception as e:
         return f"Error processing video: {str(e)}", '\n'.join(debug_msgs)
+    finally:
+        # Clean up all temporary files
+        for temp_file in temp_files:
+            if os.path.exists(temp_file):
+                try:
+                    os.remove(temp_file)
+                except Exception as cleanup_error:
+                    logging.warning(f"Failed to cleanup {temp_file}: {cleanup_error}")

 def toggle_input_mode(input_mode):
     """Toggle between webcam and video file input"""