yonigozlan (HF Staff) committed
Commit 5b5416f · 1 Parent(s): ba25bef

update for zero gpu

Files changed (3):
  1. README.md +1 -1
  2. app.py +67 -116
  3. requirements.txt +1 -1
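For context, the main ZeroGPU change in app.py keeps the model on CPU at startup and only moves it to CUDA inside a function decorated with @spaces.GPU(), which is when ZeroGPU hardware attaches a GPU. Below is a minimal sketch of that pattern (illustrative only: the name run_on_gpu and the stand-in model are hypothetical, and the real propagate_masks also deep-copies the processor and inference session and streams progress with yield):

    import spaces
    import torch

    # Stand-in for the SAM2 video model; app.py loads its model on CPU at
    # startup via AutoModel.from_pretrained(...).
    model = torch.nn.Linear(4, 4)

    @spaces.GPU()  # a GPU is attached only while this function runs
    def run_on_gpu(x: torch.Tensor) -> torch.Tensor:
        # Move the model to CUDA for this call (the real app moves a deepcopy),
        # run inference, and return the result on CPU.
        model.to("cuda")
        with torch.inference_mode():
            return model(x.to("cuda")).cpu()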
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 👀
  colorFrom: purple
  colorTo: indigo
  sdk: gradio
- sdk_version: 5.42.0
+ sdk_version: 5.47.2
  app_file: app.py
  pinned: false
  license: apache-2.0
app.py CHANGED
@@ -1,15 +1,17 @@
  import colorsys
  import gc
+ from copy import deepcopy
  from typing import Optional
  
  import gradio as gr
  import numpy as np
+ import spaces
  import torch
  from gradio.themes import Soft
  from PIL import Image, ImageDraw
  
  # Prefer local transformers in the workspace
- from transformers import Sam2VideoModel, Sam2VideoProcessor
+ from transformers import AutoModel, Sam2VideoProcessor
  
  
  def pastel_color_for_object(obj_id: int) -> tuple[int, int, int]:
@@ -52,10 +54,12 @@ def try_load_video_frames(video_path_or_url: str) -> tuple[list[Image.Image], di
  cap.release()
  if fps_val and fps_val > 0:
  info["fps"] = float(fps_val)
- except Exception:
+ except Exception as e:
+ print(f"Failed to render video with cv2: {e}")
  pass
  return pil_frames, info
- except Exception:
+ except Exception as e:
+ print(f"Failed to load video with transformers.video_utils: {e}")
  # Fallback to OpenCV
  try:
  import cv2 # type: ignore
@@ -115,7 +119,7 @@ def overlay_masks_on_frame(
  
  
  def get_device_and_dtype() -> tuple[str, torch.dtype]:
- device = "cuda" if torch.cuda.is_available() else "cpu"
+ device = "cpu"
  dtype = torch.bfloat16
  return device, dtype
  
@@ -127,9 +131,9 @@ class AppState:
  def reset(self):
  self.video_frames: list[Image.Image] = []
  self.inference_session = None
- self.model: Optional[Sam2VideoModel] = None
+ self.model: Optional[AutoModel] = None
  self.processor: Optional[Sam2VideoProcessor] = None
- self.device: str = "cuda"
+ self.device: str = "cpu"
  self.dtype: torch.dtype = torch.bfloat16
  self.video_fps: float | None = None
  self.masks_by_frame: dict[int, dict[int, np.ndarray]] = {}
@@ -153,6 +157,9 @@ class AppState:
  self.model_repo_id: str | None = None
  self.session_repo_id: str | None = None
  
+ def __repr__(self):
+ return f"AppState(video_frames={self.video_frames}, inference_session={self.inference_session is not None}, model={self.model is not None}, processor={self.processor is not None}, device={self.device}, dtype={self.dtype}, video_fps={self.video_fps}, masks_by_frame={self.masks_by_frame}, color_by_obj={self.color_by_obj}, clicks_by_frame_obj={self.clicks_by_frame_obj}, boxes_by_frame_obj={self.boxes_by_frame_obj}, composited_frames={self.composited_frames}, current_frame_idx={self.current_frame_idx}, current_obj_id={self.current_obj_id}, current_label={self.current_label}, current_clear_old={self.current_clear_old}, current_prompt_type={self.current_prompt_type}, pending_box_start={self.pending_box_start}, pending_box_start_frame_idx={self.pending_box_start_frame_idx}, pending_box_start_obj_id={self.pending_box_start_obj_id}, is_switching_model={self.is_switching_model}, model_repo_key={self.model_repo_key}, model_repo_id={self.model_repo_id}, session_repo_id={self.session_repo_id})"
+
  @property
  def num_frames(self) -> int:
  return len(self.video_frames)
@@ -168,29 +175,18 @@ def _model_repo_from_key(key: str) -> str:
  return mapping.get(key, mapping["base_plus"])
  
  
- def load_model_if_needed(GLOBAL_STATE: gr.State) -> tuple[Sam2VideoModel, Sam2VideoProcessor, str, torch.dtype]:
+ def load_model_if_needed(GLOBAL_STATE: gr.State) -> tuple[AutoModel, Sam2VideoProcessor, str, torch.dtype]:
  desired_repo = _model_repo_from_key(GLOBAL_STATE.model_repo_key)
  if GLOBAL_STATE.model is not None and GLOBAL_STATE.processor is not None:
  if GLOBAL_STATE.model_repo_id == desired_repo:
  return GLOBAL_STATE.model, GLOBAL_STATE.processor, GLOBAL_STATE.device, GLOBAL_STATE.dtype
  # Different repo requested: dispose current and reload
- try:
- del GLOBAL_STATE.model
- except Exception:
- pass
- try:
- del GLOBAL_STATE.processor
- except Exception:
- pass
  GLOBAL_STATE.model = None
  GLOBAL_STATE.processor = None
  print(f"Loading model from {desired_repo}")
  device, dtype = get_device_and_dtype()
  # free up the gpu memory
- torch.cuda.empty_cache()
- gc.collect()
- print("device", device)
- model = Sam2VideoModel.from_pretrained(desired_repo)
+ model = AutoModel.from_pretrained(desired_repo)
  processor = Sam2VideoProcessor.from_pretrained(desired_repo)
  model.to(device, dtype=dtype)
  
@@ -216,24 +212,11 @@ def ensure_session_for_current_model(GLOBAL_STATE: gr.State) -> None:
  GLOBAL_STATE.clicks_by_frame_obj.clear()
  GLOBAL_STATE.boxes_by_frame_obj.clear()
  GLOBAL_STATE.composited_frames.clear()
- # Dispose previous session cleanly
- try:
- if GLOBAL_STATE.inference_session is not None:
- GLOBAL_STATE.inference_session.reset_inference_session()
- except Exception:
- pass
  GLOBAL_STATE.inference_session = None
- gc.collect()
- try:
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- except Exception:
- pass
  GLOBAL_STATE.inference_session = processor.init_video_session(
- video=GLOBAL_STATE.video_frames,
  inference_device=device,
  video_storage_device="cpu",
- torch_dtype=dtype,
+ dtype=dtype,
  )
  GLOBAL_STATE.session_repo_id = desired_repo
  
@@ -267,43 +250,21 @@ def init_video_session(GLOBAL_STATE: gr.State, video: str | dict) -> tuple[AppSt
  # Enforce max duration of 8 seconds (trim if longer)
  MAX_SECONDS = 8.0
  trimmed_note = ""
- fps_in = None
- if isinstance(info, dict) and info.get("fps"):
- try:
- fps_in = float(info["fps"]) or None
- except Exception:
- fps_in = None
- if fps_in is not None:
- max_frames_allowed = int(MAX_SECONDS * fps_in)
- if len(frames) > max_frames_allowed:
- frames = frames[:max_frames_allowed]
- trimmed_note = f" (trimmed to {int(MAX_SECONDS)}s = {len(frames)} frames)"
- if isinstance(info, dict):
- info["num_frames"] = len(frames)
- else:
- # Fallback when FPS unknown: assume ~30 FPS and cap to 240 frames (~8s)
- max_frames_allowed = 240
- if len(frames) > max_frames_allowed:
- frames = frames[:max_frames_allowed]
- trimmed_note = " (trimmed to 240 frames ~8s @30fps)"
- if isinstance(info, dict):
- info["num_frames"] = len(frames)
-
+ fps_in = info.get("fps")
+ max_frames_allowed = int(MAX_SECONDS * fps_in)
+ if len(frames) > max_frames_allowed:
+ frames = frames[:max_frames_allowed]
+ trimmed_note = f" (trimmed to {int(MAX_SECONDS)}s = {len(frames)} frames)"
+ if isinstance(info, dict):
+ info["num_frames"] = len(frames)
  GLOBAL_STATE.video_frames = frames
  # Try to capture original FPS if provided by loader
- GLOBAL_STATE.video_fps = None
- if isinstance(info, dict) and info.get("fps"):
- try:
- GLOBAL_STATE.video_fps = float(info["fps"]) or None
- except Exception:
- GLOBAL_STATE.video_fps = None
-
+ GLOBAL_STATE.video_fps = float(fps_in)
  # Initialize session
  inference_session = processor.init_video_session(
- video=frames,
  inference_device=device,
  video_storage_device="cpu",
- torch_dtype=dtype,
+ dtype=dtype,
  )
  GLOBAL_STATE.inference_session = inference_session
  
@@ -414,6 +375,12 @@ def on_image_click(
  processor = state.processor
  model = state.model
  inference_session = state.inference_session
+ original_size = None
+ pixel_values = None
+ if inference_session.processed_frames is None or frame_idx not in inference_session.processed_frames:
+ inputs = processor(images=state.video_frames[frame_idx], device=state.device, return_tensors="pt")
+ original_size = inputs.original_sizes[0]
+ pixel_values = inputs.pixel_values[0]
  
  if state.current_prompt_type == "Boxes":
  # Two-click box input
@@ -445,6 +412,7 @@
  obj_ids=int(obj_id),
  input_boxes=[[[x_min, y_min, x_max, y_max]]],
  clear_old_inputs=True, # For boxes, always clear old inputs
+ original_size=original_size,
  )
  
  frame_boxes = state.boxes_by_frame_obj.setdefault(int(frame_idx), {})
@@ -467,6 +435,7 @@
  obj_ids=int(obj_id),
  input_points=[[[[int(x), int(y)]]]],
  input_labels=[[[int(label_int)]]],
+ original_size=original_size,
  clear_old_inputs=bool(clear_old),
  )
  
@@ -478,12 +447,8 @@
  state.composited_frames.pop(int(frame_idx), None)
  
  # Forward on that frame
- device_type = "cuda" if state.device == "cuda" else "cpu"
- with torch.inference_mode(), torch.autocast(device_type=device_type, dtype=state.dtype):
- outputs = model(
- inference_session=inference_session,
- frame_idx=int(frame_idx),
- )
+ with torch.inference_mode():
+ outputs = model(inference_session=inference_session, frame=pixel_values, frame_idx=int(frame_idx))
  
  H = inference_session.video_height
  W = inference_session.video_width
@@ -509,31 +474,37 @@
  return update_frame_display(state, int(frame_idx))
  
  
+ @spaces.GPU()
  def propagate_masks(GLOBAL_STATE: gr.State):
  if GLOBAL_STATE is None or GLOBAL_STATE.inference_session is None:
- yield "Load a video first.", gr.update()
- return
+ # yield GLOBAL_STATE, "Load a video first.", gr.update()
+ return GLOBAL_STATE, "Load a video first.", gr.update()
  
- processor = GLOBAL_STATE.processor
- model = GLOBAL_STATE.model
- inference_session = GLOBAL_STATE.inference_session
+ processor = deepcopy(GLOBAL_STATE.processor)
+ model = deepcopy(GLOBAL_STATE.model)
+ inference_session = deepcopy(GLOBAL_STATE.inference_session)
+ # set inference device to cuda to use zero gpu
+ inference_session.inference_device = "cuda"
+ inference_session.cache.inference_device = "cuda"
+ model.to("cuda")
  
  total = max(1, GLOBAL_STATE.num_frames)
  processed = 0
  
  # Initial status; no slider change yet
- yield f"Propagating masks: {processed}/{total}", gr.update()
+ yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update()
  
- device_type = "cuda" if GLOBAL_STATE.device == "cuda" else "cpu"
  last_frame_idx = 0
- with torch.inference_mode(), torch.autocast(device_type=device_type, dtype=GLOBAL_STATE.dtype):
- for sam2_video_output in model.propagate_in_video_iterator(inference_session):
+ with torch.inference_mode():
+ for frame_idx, frame in enumerate(GLOBAL_STATE.video_frames):
+ pixel_values = None
+ if inference_session.processed_frames is None or frame_idx not in inference_session.processed_frames:
+ pixel_values = processor(images=frame, device="cuda", return_tensors="pt").pixel_values[0]
+ sam2_video_output = model(inference_session=inference_session, frame=pixel_values, frame_idx=frame_idx)
  H = inference_session.video_height
  W = inference_session.video_width
  pred_masks = sam2_video_output.pred_masks.detach().cpu()
  video_res_masks = processor.post_process_masks([pred_masks], original_sizes=[[H, W]])[0]
-
- frame_idx = int(sam2_video_output.frame_idx)
  last_frame_idx = frame_idx
  masks_for_frame: dict[int, np.ndarray] = {}
  obj_ids_order = list(inference_session.obj_ids)
@@ -546,16 +517,13 @@ def propagate_masks(GLOBAL_STATE: gr.State):
  
  processed += 1
  # Every 15th frame (or last), move slider to current frame to update preview via slider binding
- if processed % 15 == 0 or processed == total:
- yield f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
- else:
- yield f"Propagating masks: {processed}/{total}", gr.update()
+ if processed % 30 == 0 or processed == total:
+ yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
+
+ text = f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects."
  
  # Final status; ensure slider points to last processed frame
- yield (
- f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects.",
- gr.update(value=last_frame_idx),
- )
+ yield GLOBAL_STATE, text, gr.update(value=last_frame_idx)
  
  
  def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, i
@@ -581,11 +549,6 @@ def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, i
  pass
  GLOBAL_STATE.inference_session = None
  gc.collect()
- try:
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- except Exception:
- pass
  ensure_session_for_current_model(GLOBAL_STATE)
  
  # Keep current slider index if possible
@@ -786,29 +749,17 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
  out_path = "/tmp/sam2_playback.mp4"
  # Prefer imageio with PyAV/ffmpeg to respect exact fps
  try:
- import imageio.v3 as iio # type: ignore
+ import cv2 # type: ignore
  
- iio.imwrite(out_path, [fr[:, :, ::-1] for fr in frames_np], plugin="pyav", fps=fps)
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+ writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
+ for fr_bgr in frames_np:
+ writer.write(fr_bgr)
+ writer.release()
  return out_path
- except Exception:
- # Fallbacks
- try:
- import imageio.v2 as imageio # type: ignore
-
- imageio.mimsave(out_path, [fr[:, :, ::-1] for fr in frames_np], fps=fps)
- return out_path
- except Exception:
- try:
- import cv2 # type: ignore
-
- fourcc = cv2.VideoWriter_fourcc(*"mp4v")
- writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
- for fr_bgr in frames_np:
- writer.write(fr_bgr)
- writer.release()
- return out_path
- except Exception as e:
- raise gr.Error(f"Failed to render video: {e}")
+ except Exception as e:
+ print(f"Failed to render video with cv2: {e}")
+ raise gr.Error(f"Failed to render video: {e}")
  
  render_btn.click(_render_video, inputs=[GLOBAL_STATE], outputs=[playback_video])
  
@@ -816,7 +767,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
  propagate_btn.click(
  propagate_masks,
  inputs=[GLOBAL_STATE],
- outputs=[propagate_status, frame_slider],
+ outputs=[GLOBAL_STATE, propagate_status, frame_slider],
  )
  
  reset_btn.click(
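Because propagate_masks now yields (state, status text, slider update) tuples, the click binding lists GLOBAL_STATE, propagate_status, and frame_slider as outputs so each yield streams into all three components. A generic sketch of this Gradio streaming pattern (toy components, not the app's actual UI):

    import gradio as gr

    def count_up(total):
        # Each yield must match outputs=[status, slider] in the click() below.
        for i in range(1, int(total) + 1):
            yield f"Step {i}/{int(total)}", gr.update(value=i)

    with gr.Blocks() as demo:
        total = gr.Number(value=5, label="Steps")
        status = gr.Textbox(label="Status")
        slider = gr.Slider(0, 100, value=0, label="Progress")
        gr.Button("Run").click(count_up, inputs=[total], outputs=[status, slider])

    demo.launch()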
requirements.txt CHANGED
@@ -1,5 +1,5 @@
  gradio
- git+https://github.com/SangbumChoi/transformers.git@sam2
+ git+https://github.com/yonigozlan/transformers.git@add-edgetam
  torch
  torchvision
  pillow