file upload
app.py (CHANGED)
@@ -27,33 +27,51 @@ has_xpu = hasattr(torch, 'xpu') and torch.xpu.is_available()
 def update_model(model_id, device):
     if model_cache['model_id'] != model_id or model_cache['device'] != device:
         logging.info(f'Loading model {model_id} on {device}')
+        try:
+            processor = AutoProcessor.from_pretrained(model_id)
+            # Load model with appropriate precision for each device
+            if device == 'cuda':
+                # Use bfloat16 for CUDA for performance
+                model = AutoModelForImageTextToText.from_pretrained(
+                    model_id,
+                    torch_dtype=torch.bfloat16,
+                    _attn_implementation='flash_attention_2'
+                ).to('cuda')
+            elif device == 'xpu' and has_xpu:
+                # Use float32 on XPU to avoid bfloat16 layernorm issues
+                model = AutoModelForImageTextToText.from_pretrained(
+                    model_id,
+                    torch_dtype=torch.float32
+                ).to('xpu')
+            else:
+                # Default to float32 on CPU
+                model = AutoModelForImageTextToText.from_pretrained(model_id).to('cpu')
+            model.eval()
+            model_cache.update({'model_id': model_id, 'processor': processor, 'model': model, 'device': device})
+        except Exception as e:
+            logging.error(f'Error loading model: {e}')
+            raise e
 
 def extract_frames_from_video(video_path, max_frames=10):
     """Extract frames from video file for processing"""
+    if not os.path.exists(video_path):
+        raise FileNotFoundError(f"Video file not found: {video_path}")
+
+    # Validate video file
+    if not video_path.lower().endswith(('.mp4', '.avi', '.mov', '.mkv', '.webm')):
+        raise ValueError("Unsupported video format. Please use MP4, AVI, MOV, MKV, or WEBM.")
+
     cap = cv2.VideoCapture(video_path)
+    if not cap.isOpened():
+        raise ValueError(f"Cannot open video file: {video_path}")
+
     frames = []
     frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
 
+    if frame_count == 0:
+        cap.release()
+        raise ValueError("Video file appears to be empty or corrupted")
+
     # Calculate step size to extract evenly distributed frames
     step = max(1, frame_count // max_frames)
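Note: the sampling loop that consumes `step` sits between these hunks and is unchanged by this commit, so it is not shown above. For orientation, a minimal sketch of the usual evenly-spaced sampling pattern (the loop body below is an illustration, not the actual app.py code):

# Illustration only: a typical evenly-spaced sampling loop over the
# frame_count/step values computed above; not taken from app.py.
import cv2

def sample_frames(video_path, max_frames=10):
    cap = cv2.VideoCapture(video_path)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    step = max(1, frame_count // max_frames)
    frames = []
    for idx in range(0, frame_count, step):
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)  # seek to frame idx
        ok, frame = cap.read()
        if not ok:
            break
        frames.append(frame)
        if len(frames) >= max_frames:
            break
    cap.release()
    return frames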
@@ -77,71 +95,78 @@ def extract_frames_from_video(video_path, max_frames=10):
 def caption_frame(frame, model_id, interval_ms, sys_prompt, usr_prompt, device):
     """Caption a single frame (used for webcam streaming)"""
     debug_msgs = []
+    try:
+        update_model(model_id, device)
+        processor = model_cache['processor']
+        model = model_cache['model']
 
+        # Control capture interval
+        time.sleep(interval_ms / 1000)
 
+        # Preprocess frame
+        t0 = time.time()
+        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        pil_img = Image.fromarray(rgb)
+        temp_path = 'frame.jpg'
+        pil_img.save(temp_path, format='JPEG', quality=50)
+        debug_msgs.append(f'Preprocess: {int((time.time()-t0)*1000)} ms')
 
+        # Prepare multimodal chat messages
+        messages = [
+            {'role': 'system', 'content': [{'type': 'text', 'text': sys_prompt}]},
+            {'role': 'user', 'content': [
+                {'type': 'image', 'url': temp_path},
+                {'type': 'text', 'text': usr_prompt}
+            ]}
+        ]
 
+        # Tokenize and encode
+        t1 = time.time()
+        inputs = processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors='pt'
+        )
+        # Move inputs to correct device and dtype (matching model parameters)
+        param_dtype = next(model.parameters()).dtype
+        cast_inputs = {}
+        for k, v in inputs.items():
+            if isinstance(v, torch.Tensor):
+                if v.dtype.is_floating_point:
+                    # cast floating-point tensors to model's parameter dtype
+                    cast_inputs[k] = v.to(device=model.device, dtype=param_dtype)
+                else:
+                    # move integer/mask tensors without changing dtype
+                    cast_inputs[k] = v.to(device=model.device)
             else:
-            cast_inputs[k] = v
-    inputs = cast_inputs
-    debug_msgs.append(f'Tokenize: {int((time.time()-t1)*1000)} ms')
+                cast_inputs[k] = v
+        inputs = cast_inputs
+        debug_msgs.append(f'Tokenize: {int((time.time()-t1)*1000)} ms')
 
+        # Inference
+        t2 = time.time()
+        outputs = model.generate(**inputs, do_sample=False, max_new_tokens=128)
+        debug_msgs.append(f'Inference: {int((time.time()-t2)*1000)} ms')
 
+        # Decode and strip history
+        t3 = time.time()
+        raw = processor.batch_decode(outputs, skip_special_tokens=True)[0]
+        debug_msgs.append(f'Decode: {int((time.time()-t3)*1000)} ms')
+        if "Assistant:" in raw:
+            caption = raw.split("Assistant:")[-1].strip()
+        else:
+            lines = raw.splitlines()
+            caption = lines[-1].strip() if len(lines) > 1 else raw.strip()
 
+        # Clean up temp file
+        if os.path.exists(temp_path):
+            os.remove(temp_path)
+
+        return caption, '\n'.join(debug_msgs)
+    except Exception as e:
+        return f"Error: {str(e)}", '\n'.join(debug_msgs)
 
 @spaces.GPU
 def process_video_file(video_file, model_id, sys_prompt, usr_prompt, device, max_frames):
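Note: the cast loop above exists because `apply_chat_template(..., return_dict=True)` returns a mix of integer tensors (`input_ids`, `attention_mask`) and floating-point tensors (`pixel_values`); only the float tensors may follow the model into bfloat16, while token ids must stay integral for the embedding lookup. A self-contained demonstration of that invariant (tensor names and shapes are typical examples, not captured from a live processor):

import torch

# Stand-in for a processor output dict (hypothetical values).
inputs = {
    'input_ids': torch.tensor([[101, 2023, 102]]),  # int64: must not be cast
    'pixel_values': torch.rand(1, 3, 224, 224),     # float32: follows the model dtype
}
param_dtype = torch.bfloat16  # e.g. what next(model.parameters()).dtype yields on CUDA here
cast_inputs = {
    k: v.to(dtype=param_dtype) if v.dtype.is_floating_point else v
    for k, v in inputs.items()
}
assert cast_inputs['input_ids'].dtype == torch.int64
assert cast_inputs['pixel_values'].dtype == torch.bfloat16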
@@ -150,11 +175,13 @@ def process_video_file(video_file, model_id, sys_prompt, usr_prompt, device, max_frames):
         return "No video file uploaded", ""
 
     debug_msgs = []
-    processor = model_cache['processor']
-    model = model_cache['model']
+    temp_files = []  # Track temporary files for cleanup
 
     try:
+        update_model(model_id, device)
+        processor = model_cache['processor']
+        model = model_cache['model']
+
         # Extract frames from video
         t0 = time.time()
         frames = extract_frames_from_video(video_file, max_frames)
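Note: `update_model` now runs inside the `try`, so a failed model load is reported through the function's normal error return instead of escaping as an unhandled exception. The `model_cache` dict it reads is defined above this diff; its shape is presumably along these lines (a hypothetical reconstruction, inferred from the keys used in `update_model`):

# Hypothetical: module-level cache as implied by model_cache.update(...)
# in update_model; the real definition is outside this diff.
model_cache = {'model_id': None, 'processor': None, 'model': None, 'device': None}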
@@ -171,6 +198,7 @@ def process_video_file(video_file, model_id, sys_prompt, usr_prompt, device, max_frames):
             rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
             pil_img = Image.fromarray(rgb)
             temp_path = f'frame_{i}.jpg'
+            temp_files.append(temp_path)  # Track for cleanup
             pil_img.save(temp_path, format='JPEG', quality=50)
 
             # Prepare multimodal chat messages
@@ -216,17 +244,20 @@ def process_video_file(video_file, model_id, sys_prompt, usr_prompt, device, max_frames):
                 caption = lines[-1].strip() if len(lines) > 1 else raw.strip()
 
             captions.append(f"Frame {i+1}: {caption}")
-            # Clean up temp file
-            if os.path.exists(temp_path):
-                os.remove(temp_path)
-
             debug_msgs.append(f'Frame {i+1} processed in {int((time.time()-t1)*1000)} ms')
 
         return '\n\n'.join(captions), '\n'.join(debug_msgs)
 
     except Exception as e:
         return f"Error processing video: {str(e)}", '\n'.join(debug_msgs)
+    finally:
+        # Clean up all temporary files
+        for temp_file in temp_files:
+            if os.path.exists(temp_file):
+                try:
+                    os.remove(temp_file)
+                except Exception as cleanup_error:
+                    logging.warning(f"Failed to cleanup {temp_file}: {cleanup_error}")
 
 def toggle_input_mode(input_mode):
     """Toggle between webcam and video file input"""
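Note: moving cleanup into `finally` guarantees the per-frame JPEGs are deleted even when captioning raises mid-loop, which the old per-iteration cleanup could not. A pattern with the same guarantee (an alternative sketch, not what this commit does) is to write the frames into a `tempfile.TemporaryDirectory`, which deletes everything on context exit:

import os
import tempfile

def caption_with_tempdir(pil_frames):
    # Alternative sketch: the directory and all frame files are removed
    # automatically when the with-block exits, even on exceptions.
    with tempfile.TemporaryDirectory() as tmpdir:
        paths = []
        for i, img in enumerate(pil_frames):
            path = os.path.join(tmpdir, f'frame_{i}.jpg')
            img.save(path, format='JPEG', quality=50)
            paths.append(path)
        # ... run captioning over `paths` here ...
    # nothing to clean up after this point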