Spaces: Running on Zero

Commit 5b5416f · update for zero gpu
1 Parent(s): ba25bef

Files changed:
- README.md (+1 -1)
- app.py (+67 -116)
- requirements.txt (+1 -1)
README.md CHANGED

@@ -4,7 +4,7 @@ emoji: 👀
 colorFrom: purple
 colorTo: indigo
 sdk: gradio
-sdk_version: 5.…
+sdk_version: 5.47.2
 app_file: app.py
 pinned: false
 license: apache-2.0
app.py CHANGED

@@ -1,15 +1,17 @@
 import colorsys
 import gc
+from copy import deepcopy
 from typing import Optional
 
 import gradio as gr
 import numpy as np
+import spaces
 import torch
 from gradio.themes import Soft
 from PIL import Image, ImageDraw
 
 # Prefer local transformers in the workspace
-from transformers import …
+from transformers import AutoModel, Sam2VideoProcessor
 
 
 def pastel_color_for_object(obj_id: int) -> tuple[int, int, int]:
@@ -52,10 +54,12 @@ def try_load_video_frames(video_path_or_url: str) -> tuple[list[Image.Image], dict]:
             cap.release()
             if fps_val and fps_val > 0:
                 info["fps"] = float(fps_val)
-        except Exception:
+        except Exception as e:
+            print(f"Failed to render video with cv2: {e}")
             pass
         return pil_frames, info
-    except Exception:
+    except Exception as e:
+        print(f"Failed to load video with transformers.video_utils: {e}")
         # Fallback to OpenCV
         try:
             import cv2  # type: ignore
@@ -115,7 +119,7 @@ def overlay_masks_on_frame(
 
 
 def get_device_and_dtype() -> tuple[str, torch.dtype]:
-    device = "…
+    device = "cpu"
     dtype = torch.bfloat16
     return device, dtype
 
@@ -127,9 +131,9 @@ class AppState:
     def reset(self):
         self.video_frames: list[Image.Image] = []
         self.inference_session = None
-        self.model: Optional[…
+        self.model: Optional[AutoModel] = None
         self.processor: Optional[Sam2VideoProcessor] = None
-        self.device: str = "…
+        self.device: str = "cpu"
         self.dtype: torch.dtype = torch.bfloat16
         self.video_fps: float | None = None
         self.masks_by_frame: dict[int, dict[int, np.ndarray]] = {}
@@ -153,6 +157,9 @@ class AppState:
         self.model_repo_id: str | None = None
         self.session_repo_id: str | None = None
 
+    def __repr__(self):
+        return f"AppState(video_frames={self.video_frames}, inference_session={self.inference_session is not None}, model={self.model is not None}, processor={self.processor is not None}, device={self.device}, dtype={self.dtype}, video_fps={self.video_fps}, masks_by_frame={self.masks_by_frame}, color_by_obj={self.color_by_obj}, clicks_by_frame_obj={self.clicks_by_frame_obj}, boxes_by_frame_obj={self.boxes_by_frame_obj}, composited_frames={self.composited_frames}, current_frame_idx={self.current_frame_idx}, current_obj_id={self.current_obj_id}, current_label={self.current_label}, current_clear_old={self.current_clear_old}, current_prompt_type={self.current_prompt_type}, pending_box_start={self.pending_box_start}, pending_box_start_frame_idx={self.pending_box_start_frame_idx}, pending_box_start_obj_id={self.pending_box_start_obj_id}, is_switching_model={self.is_switching_model}, model_repo_key={self.model_repo_key}, model_repo_id={self.model_repo_id}, session_repo_id={self.session_repo_id})"
+
     @property
     def num_frames(self) -> int:
         return len(self.video_frames)
@@ -168,29 +175,18 @@ def _model_repo_from_key(key: str) -> str:
     return mapping.get(key, mapping["base_plus"])
 
 
-def load_model_if_needed(GLOBAL_STATE: gr.State) -> tuple[…
+def load_model_if_needed(GLOBAL_STATE: gr.State) -> tuple[AutoModel, Sam2VideoProcessor, str, torch.dtype]:
     desired_repo = _model_repo_from_key(GLOBAL_STATE.model_repo_key)
     if GLOBAL_STATE.model is not None and GLOBAL_STATE.processor is not None:
         if GLOBAL_STATE.model_repo_id == desired_repo:
             return GLOBAL_STATE.model, GLOBAL_STATE.processor, GLOBAL_STATE.device, GLOBAL_STATE.dtype
     # Different repo requested: dispose current and reload
-    try:
-        del GLOBAL_STATE.model
-    except Exception:
-        pass
-    try:
-        del GLOBAL_STATE.processor
-    except Exception:
-        pass
     GLOBAL_STATE.model = None
     GLOBAL_STATE.processor = None
     print(f"Loading model from {desired_repo}")
     device, dtype = get_device_and_dtype()
     # free up the gpu memory
-
-    gc.collect()
-    print("device", device)
-    model = Sam2VideoModel.from_pretrained(desired_repo)
+    model = AutoModel.from_pretrained(desired_repo)
     processor = Sam2VideoProcessor.from_pretrained(desired_repo)
     model.to(device, dtype=dtype)
 
@@ -216,24 +212,11 @@ def ensure_session_for_current_model(GLOBAL_STATE: gr.State) -> None:
     GLOBAL_STATE.clicks_by_frame_obj.clear()
     GLOBAL_STATE.boxes_by_frame_obj.clear()
     GLOBAL_STATE.composited_frames.clear()
-    # Dispose previous session cleanly
-    try:
-        if GLOBAL_STATE.inference_session is not None:
-            GLOBAL_STATE.inference_session.reset_inference_session()
-    except Exception:
-        pass
     GLOBAL_STATE.inference_session = None
-    gc.collect()
-    try:
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-    except Exception:
-        pass
     GLOBAL_STATE.inference_session = processor.init_video_session(
-        video=GLOBAL_STATE.video_frames,
         inference_device=device,
         video_storage_device="cpu",
-
+        dtype=dtype,
     )
     GLOBAL_STATE.session_repo_id = desired_repo
 
@@ -267,43 +250,21 @@ def init_video_session(GLOBAL_STATE: gr.State, video: str | dict) -> tuple[AppSt…
     # Enforce max duration of 8 seconds (trim if longer)
     MAX_SECONDS = 8.0
     trimmed_note = ""
-    fps_in = …
-
-
-
-
-
-
-        max_frames_allowed = int(MAX_SECONDS * fps_in)
-        if len(frames) > max_frames_allowed:
-            frames = frames[:max_frames_allowed]
-            trimmed_note = f" (trimmed to {int(MAX_SECONDS)}s = {len(frames)} frames)"
-            if isinstance(info, dict):
-                info["num_frames"] = len(frames)
-    else:
-        # Fallback when FPS unknown: assume ~30 FPS and cap to 240 frames (~8s)
-        max_frames_allowed = 240
-        if len(frames) > max_frames_allowed:
-            frames = frames[:max_frames_allowed]
-            trimmed_note = " (trimmed to 240 frames ~8s @30fps)"
-            if isinstance(info, dict):
-                info["num_frames"] = len(frames)
-
+    fps_in = info.get("fps")
+    max_frames_allowed = int(MAX_SECONDS * fps_in)
+    if len(frames) > max_frames_allowed:
+        frames = frames[:max_frames_allowed]
+        trimmed_note = f" (trimmed to {int(MAX_SECONDS)}s = {len(frames)} frames)"
+    if isinstance(info, dict):
+        info["num_frames"] = len(frames)
     GLOBAL_STATE.video_frames = frames
     # Try to capture original FPS if provided by loader
-    GLOBAL_STATE.video_fps = …
-    if isinstance(info, dict) and info.get("fps"):
-        try:
-            GLOBAL_STATE.video_fps = float(info["fps"]) or None
-        except Exception:
-            GLOBAL_STATE.video_fps = None
-
+    GLOBAL_STATE.video_fps = float(fps_in)
     # Initialize session
     inference_session = processor.init_video_session(
-        video=frames,
         inference_device=device,
         video_storage_device="cpu",
-
+        dtype=dtype,
     )
     GLOBAL_STATE.inference_session = inference_session
 
@@ -414,6 +375,12 @@ def on_image_click(
     processor = state.processor
     model = state.model
    inference_session = state.inference_session
+    original_size = None
+    pixel_values = None
+    if inference_session.processed_frames is None or frame_idx not in inference_session.processed_frames:
+        inputs = processor(images=state.video_frames[frame_idx], device=state.device, return_tensors="pt")
+        original_size = inputs.original_sizes[0]
+        pixel_values = inputs.pixel_values[0]
 
     if state.current_prompt_type == "Boxes":
         # Two-click box input
@@ -445,6 +412,7 @@
             obj_ids=int(obj_id),
             input_boxes=[[[x_min, y_min, x_max, y_max]]],
             clear_old_inputs=True,  # For boxes, always clear old inputs
+            original_size=original_size,
         )
 
         frame_boxes = state.boxes_by_frame_obj.setdefault(int(frame_idx), {})
@@ -467,6 +435,7 @@
             obj_ids=int(obj_id),
             input_points=[[[[int(x), int(y)]]]],
             input_labels=[[[int(label_int)]]],
+            original_size=original_size,
             clear_old_inputs=bool(clear_old),
         )
 
@@ -478,12 +447,8 @@
     state.composited_frames.pop(int(frame_idx), None)
 
     # Forward on that frame
-
-
-    outputs = model(
-        inference_session=inference_session,
-        frame_idx=int(frame_idx),
-    )
+    with torch.inference_mode():
+        outputs = model(inference_session=inference_session, frame=pixel_values, frame_idx=int(frame_idx))
 
     H = inference_session.video_height
     W = inference_session.video_width
@@ -509,31 +474,37 @@
     return update_frame_display(state, int(frame_idx))
 
 
+@spaces.GPU()
 def propagate_masks(GLOBAL_STATE: gr.State):
     if GLOBAL_STATE is None or GLOBAL_STATE.inference_session is None:
-        yield "Load a video first.", gr.update()
-        return
+        # yield GLOBAL_STATE, "Load a video first.", gr.update()
+        return GLOBAL_STATE, "Load a video first.", gr.update()
 
-    processor = GLOBAL_STATE.processor
-    model = GLOBAL_STATE.model
-    inference_session = GLOBAL_STATE.inference_session
+    processor = deepcopy(GLOBAL_STATE.processor)
+    model = deepcopy(GLOBAL_STATE.model)
+    inference_session = deepcopy(GLOBAL_STATE.inference_session)
+    # set inference device to cuda to use zero gpu
+    inference_session.inference_device = "cuda"
+    inference_session.cache.inference_device = "cuda"
+    model.to("cuda")
 
     total = max(1, GLOBAL_STATE.num_frames)
     processed = 0
 
     # Initial status; no slider change yet
-    yield f"Propagating masks: {processed}/{total}", gr.update()
+    yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update()
 
-    device_type = "cuda" if GLOBAL_STATE.device == "cuda" else "cpu"
     last_frame_idx = 0
-    with torch.inference_mode()…
-        for …
+    with torch.inference_mode():
+        for frame_idx, frame in enumerate(GLOBAL_STATE.video_frames):
+            pixel_values = None
+            if inference_session.processed_frames is None or frame_idx not in inference_session.processed_frames:
+                pixel_values = processor(images=frame, device="cuda", return_tensors="pt").pixel_values[0]
+            sam2_video_output = model(inference_session=inference_session, frame=pixel_values, frame_idx=frame_idx)
             H = inference_session.video_height
             W = inference_session.video_width
             pred_masks = sam2_video_output.pred_masks.detach().cpu()
             video_res_masks = processor.post_process_masks([pred_masks], original_sizes=[[H, W]])[0]
-
-            frame_idx = int(sam2_video_output.frame_idx)
             last_frame_idx = frame_idx
             masks_for_frame: dict[int, np.ndarray] = {}
             obj_ids_order = list(inference_session.obj_ids)
@@ -546,16 +517,13 @@ def propagate_masks(GLOBAL_STATE: gr.State):
 
             processed += 1
             # Every 15th frame (or last), move slider to current frame to update preview via slider binding
-            if processed % …
-                yield f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
-
-
+            if processed % 30 == 0 or processed == total:
+                yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
+
+    text = f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects."
 
     # Final status; ensure slider points to last processed frame
-    yield (
-        f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects.",
-        gr.update(value=last_frame_idx),
-    )
+    yield GLOBAL_STATE, text, gr.update(value=last_frame_idx)
 
 
 def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, int, str]:
@@ -581,11 +549,6 @@ def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, int, str]:
         pass
     GLOBAL_STATE.inference_session = None
     gc.collect()
-    try:
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-    except Exception:
-        pass
     ensure_session_for_current_model(GLOBAL_STATE)
 
     # Keep current slider index if possible
@@ -786,29 +749,17 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the…
         out_path = "/tmp/sam2_playback.mp4"
         # Prefer imageio with PyAV/ffmpeg to respect exact fps
         try:
-            import …
+            import cv2  # type: ignore
 
-
+            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+            writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
+            for fr_bgr in frames_np:
+                writer.write(fr_bgr)
+            writer.release()
             return out_path
-        except Exception:
-
-
-            import imageio.v2 as imageio  # type: ignore
-
-            imageio.mimsave(out_path, [fr[:, :, ::-1] for fr in frames_np], fps=fps)
-            return out_path
-        except Exception:
-            try:
-                import cv2  # type: ignore
-
-                fourcc = cv2.VideoWriter_fourcc(*"mp4v")
-                writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
-                for fr_bgr in frames_np:
-                    writer.write(fr_bgr)
-                writer.release()
-                return out_path
-            except Exception as e:
-                raise gr.Error(f"Failed to render video: {e}")
+        except Exception as e:
+            print(f"Failed to render video with cv2: {e}")
+            raise gr.Error(f"Failed to render video: {e}")
 
     render_btn.click(_render_video, inputs=[GLOBAL_STATE], outputs=[playback_video])
 
@@ -816,7 +767,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the…
     propagate_btn.click(
        propagate_masks,
        inputs=[GLOBAL_STATE],
-        outputs=[propagate_status, frame_slider],
+        outputs=[GLOBAL_STATE, propagate_status, frame_slider],
    )
 
    reset_btn.click(
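The rewritten `propagate_masks` is a Gradio generator: inside the `@spaces.GPU()` call it deep-copies the CPU-resident model and inference session, moves the copies to CUDA, and yields `(GLOBAL_STATE, status_text, slider_update)` every 30 processed frames, which is why the click wiring above now lists three outputs. Below is a small, self-contained sketch of that streaming pattern with hypothetical component names; it only illustrates standard Gradio generator behavior, not this Space's exact handler.

import time

import gradio as gr


def slow_job():
    n_steps = 20
    for i in range(n_steps):
        time.sleep(0.1)  # stand-in for per-frame SAM2 inference
        # each yield pushes fresh values into (status, slider) while the job is still running
        yield f"Propagating masks: {i + 1}/{n_steps}", gr.update(value=i)
    yield "Done.", gr.update(value=n_steps - 1)


with gr.Blocks() as demo:
    status = gr.Textbox(label="Status")
    slider = gr.Slider(0, 19, step=1, label="Frame")
    gr.Button("Run").click(slow_job, inputs=None, outputs=[status, slider])

demo.launch()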
requirements.txt CHANGED

@@ -1,5 +1,5 @@
 gradio
-git+https://github.com/…
+git+https://github.com/yonigozlan/transformers.git@add-edgetam
 torch
 torchvision
 pillow
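The requirements now pin `transformers` to the `add-edgetam` branch of a fork rather than a release, presumably because that branch carries the SAM2 video classes app.py imports. A quick, hedged sanity check after installing (availability of these classes depends on the pinned branch):

# After `pip install -r requirements.txt`, these imports should succeed if the
# pinned branch provides the SAM2 video support that app.py relies on.
from transformers import AutoModel, Sam2VideoProcessor  # noqa: F401

print("SAM2 video classes are importable")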