yonigozlan HF Staff commited on
Commit
942c318
·
1 Parent(s): 5b5416f

load video with cv2

Browse files
Files changed (1) hide show
  1. app.py +29 -64
app.py CHANGED
@@ -3,6 +3,7 @@ import gc
3
  from copy import deepcopy
4
  from typing import Optional
5
 
 
6
  import gradio as gr
7
  import numpy as np
8
  import spaces
@@ -10,7 +11,6 @@ import torch
10
  from gradio.themes import Soft
11
  from PIL import Image, ImageDraw
12
 
13
- # Prefer local transformers in the workspace
14
  from transformers import AutoModel, Sam2VideoProcessor
15
 
16
 
@@ -32,56 +32,25 @@ def try_load_video_frames(video_path_or_url: str) -> tuple[list[Image.Image], di
32
  """Load video frames as PIL Images using transformers.video_utils if available,
33
  otherwise fall back to OpenCV. Returns (frames, info).
34
  """
35
- try:
36
- from transformers.video_utils import load_video # type: ignore
37
-
38
- frames, info = load_video(video_path_or_url)
39
- # Ensure PIL format
40
- pil_frames = []
41
- for fr in frames:
42
- if isinstance(fr, Image.Image):
43
- pil_frames.append(fr.convert("RGB"))
44
- else:
45
- pil_frames.append(Image.fromarray(fr).convert("RGB"))
46
- info = info if info is not None else {}
47
- # Ensure fps present when possible (fallback to cv2 probe)
48
- if "fps" not in info or not info.get("fps"):
49
- try:
50
- import cv2 # type: ignore
51
-
52
- cap = cv2.VideoCapture(video_path_or_url)
53
- fps_val = cap.get(cv2.CAP_PROP_FPS)
54
- cap.release()
55
- if fps_val and fps_val > 0:
56
- info["fps"] = float(fps_val)
57
- except Exception as e:
58
- print(f"Failed to render video with cv2: {e}")
59
- pass
60
- return pil_frames, info
61
- except Exception as e:
62
- print(f"Failed to load video with transformers.video_utils: {e}")
63
- # Fallback to OpenCV
64
- try:
65
- import cv2 # type: ignore
66
-
67
- cap = cv2.VideoCapture(video_path_or_url)
68
- frames = []
69
- while cap.isOpened():
70
- ret, frame = cap.read()
71
- if not ret:
72
- break
73
- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
74
- frames.append(Image.fromarray(frame_rgb))
75
- # Gather fps if available
76
- fps_val = cap.get(cv2.CAP_PROP_FPS)
77
- cap.release()
78
- info = {
79
- "num_frames": len(frames),
80
- "fps": float(fps_val) if fps_val and fps_val > 0 else None,
81
- }
82
- return frames, info
83
- except Exception as e:
84
- raise RuntimeError(f"Failed to load video: {e}")
85
 
86
 
87
  def overlay_masks_on_frame(
@@ -196,14 +165,12 @@ def load_model_if_needed(GLOBAL_STATE: gr.State) -> tuple[AutoModel, Sam2VideoPr
196
  GLOBAL_STATE.dtype = dtype
197
  GLOBAL_STATE.model_repo_id = desired_repo
198
 
199
- return model, processor, device, dtype
200
-
201
 
202
  def ensure_session_for_current_model(GLOBAL_STATE: gr.State) -> None:
203
  """Ensure the model/processor match the selected repo and inference_session exists.
204
  If a video is already loaded, re-initialize the inference session when needed.
205
  """
206
- model, processor, device, dtype = load_model_if_needed(GLOBAL_STATE)
207
  desired_repo = _model_repo_from_key(GLOBAL_STATE.model_repo_key)
208
  if GLOBAL_STATE.inference_session is None or GLOBAL_STATE.session_repo_id != desired_repo:
209
  if GLOBAL_STATE.video_frames:
@@ -213,10 +180,10 @@ def ensure_session_for_current_model(GLOBAL_STATE: gr.State) -> None:
213
  GLOBAL_STATE.boxes_by_frame_obj.clear()
214
  GLOBAL_STATE.composited_frames.clear()
215
  GLOBAL_STATE.inference_session = None
216
- GLOBAL_STATE.inference_session = processor.init_video_session(
217
- inference_device=device,
218
  video_storage_device="cpu",
219
- dtype=dtype,
220
  )
221
  GLOBAL_STATE.session_repo_id = desired_repo
222
 
@@ -229,7 +196,7 @@ def init_video_session(GLOBAL_STATE: gr.State, video: str | dict) -> tuple[AppSt
229
  GLOBAL_STATE.masks_by_frame = {}
230
  GLOBAL_STATE.color_by_obj = {}
231
 
232
- model, processor, device, dtype = load_model_if_needed(GLOBAL_STATE)
233
 
234
  # Gradio Video may provide a dict with 'name' or a direct file path
235
  video_path: Optional[str] = None
@@ -261,10 +228,10 @@ def init_video_session(GLOBAL_STATE: gr.State, video: str | dict) -> tuple[AppSt
261
  # Try to capture original FPS if provided by loader
262
  GLOBAL_STATE.video_fps = float(fps_in)
263
  # Initialize session
264
- inference_session = processor.init_video_session(
265
- inference_device=device,
266
  video_storage_device="cpu",
267
- dtype=dtype,
268
  )
269
  GLOBAL_STATE.inference_session = inference_session
270
 
@@ -272,7 +239,7 @@ def init_video_session(GLOBAL_STATE: gr.State, video: str | dict) -> tuple[AppSt
272
  max_idx = len(frames) - 1
273
  status = (
274
  f"Loaded {len(frames)} frames @ {GLOBAL_STATE.video_fps or 'unknown'} fps{trimmed_note}. "
275
- f"Device: {device}, dtype: bfloat16"
276
  )
277
  return GLOBAL_STATE, 0, max_idx, first_frame, status
278
 
@@ -749,8 +716,6 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
749
  out_path = "/tmp/sam2_playback.mp4"
750
  # Prefer imageio with PyAV/ffmpeg to respect exact fps
751
  try:
752
- import cv2 # type: ignore
753
-
754
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
755
  writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
756
  for fr_bgr in frames_np:
 
3
  from copy import deepcopy
4
  from typing import Optional
5
 
6
+ import cv2
7
  import gradio as gr
8
  import numpy as np
9
  import spaces
 
11
  from gradio.themes import Soft
12
  from PIL import Image, ImageDraw
13
 
 
14
  from transformers import AutoModel, Sam2VideoProcessor
15
 
16
 
 
32
  """Load video frames as PIL Images using transformers.video_utils if available,
33
  otherwise fall back to OpenCV. Returns (frames, info).
34
  """
35
+
36
+ cap = cv2.VideoCapture(video_path_or_url)
37
+ frames = []
38
+ print("loading video frames")
39
+ while cap.isOpened():
40
+ ret, frame = cap.read()
41
+ if not ret:
42
+ break
43
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
44
+ frames.append(Image.fromarray(frame_rgb))
45
+ # Gather fps if available
46
+ fps_val = cap.get(cv2.CAP_PROP_FPS)
47
+ cap.release()
48
+ print("loaded video frames")
49
+ info = {
50
+ "num_frames": len(frames),
51
+ "fps": float(fps_val) if fps_val and fps_val > 0 else None,
52
+ }
53
+ return frames, info
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
 
56
  def overlay_masks_on_frame(
 
165
  GLOBAL_STATE.dtype = dtype
166
  GLOBAL_STATE.model_repo_id = desired_repo
167
 
 
 
168
 
169
  def ensure_session_for_current_model(GLOBAL_STATE: gr.State) -> None:
170
  """Ensure the model/processor match the selected repo and inference_session exists.
171
  If a video is already loaded, re-initialize the inference session when needed.
172
  """
173
+ load_model_if_needed(GLOBAL_STATE)
174
  desired_repo = _model_repo_from_key(GLOBAL_STATE.model_repo_key)
175
  if GLOBAL_STATE.inference_session is None or GLOBAL_STATE.session_repo_id != desired_repo:
176
  if GLOBAL_STATE.video_frames:
 
180
  GLOBAL_STATE.boxes_by_frame_obj.clear()
181
  GLOBAL_STATE.composited_frames.clear()
182
  GLOBAL_STATE.inference_session = None
183
+ GLOBAL_STATE.inference_session = GLOBAL_STATE.processor.init_video_session(
184
+ inference_device=GLOBAL_STATE.device,
185
  video_storage_device="cpu",
186
+ dtype=GLOBAL_STATE.dtype,
187
  )
188
  GLOBAL_STATE.session_repo_id = desired_repo
189
 
 
196
  GLOBAL_STATE.masks_by_frame = {}
197
  GLOBAL_STATE.color_by_obj = {}
198
 
199
+ load_model_if_needed(GLOBAL_STATE)
200
 
201
  # Gradio Video may provide a dict with 'name' or a direct file path
202
  video_path: Optional[str] = None
 
228
  # Try to capture original FPS if provided by loader
229
  GLOBAL_STATE.video_fps = float(fps_in)
230
  # Initialize session
231
+ inference_session = GLOBAL_STATE.processor.init_video_session(
232
+ inference_device=GLOBAL_STATE.device,
233
  video_storage_device="cpu",
234
+ dtype=GLOBAL_STATE.dtype,
235
  )
236
  GLOBAL_STATE.inference_session = inference_session
237
 
 
239
  max_idx = len(frames) - 1
240
  status = (
241
  f"Loaded {len(frames)} frames @ {GLOBAL_STATE.video_fps or 'unknown'} fps{trimmed_note}. "
242
+ f"Device: {GLOBAL_STATE.device}, dtype: bfloat16"
243
  )
244
  return GLOBAL_STATE, 0, max_idx, first_frame, status
245
 
 
716
  out_path = "/tmp/sam2_playback.mp4"
717
  # Prefer imageio with PyAV/ffmpeg to respect exact fps
718
  try:
 
 
719
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
720
  writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
721
  for fr_bgr in frames_np: