NihalGazi committed on
Commit
7ffd52f
·
verified ·
1 Parent(s): 83956cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -44
app.py CHANGED
@@ -4,6 +4,7 @@ import os
4
  import tempfile
5
  from tqdm import tqdm
6
  import gradio as gr
 
7
 
8
  # -----------------------------
9
  # Function to extract frames from the video.
@@ -12,15 +13,14 @@ def extract_frames(video_path):
12
  """
13
  Opens the video file and extracts all frames into a list.
14
  Logic:
15
- - Open the video with cv2.VideoCapture.
16
- - Read frames until no frame is returned.
17
  """
18
  cap = cv2.VideoCapture(video_path)
19
  frames = []
20
  while True:
21
- ret, frame = cap.read() # ret is True if a frame is successfully read.
22
  if not ret:
23
- break # Exit loop when no more frames.
24
  frames.append(frame)
25
  cap.release()
26
  return frames
@@ -32,109 +32,122 @@ def apply_style_propagation(frames, style_image_path):
32
  """
33
  Applies the style from the provided image onto each video frame.
34
  Logic:
35
- - Load and resize the style image to match video dimensions.
36
- - Use the style image as the first styled frame.
37
- - For each subsequent frame, compute optical flow between consecutive frames.
38
- - Warp the previous styled frame using the flow so that the style follows the motion.
39
  """
40
- # Load and resize the style image.
41
  style_image = cv2.imread(style_image_path)
42
  h, w = frames[0].shape[:2]
43
  style_image = cv2.resize(style_image, (w, h))
44
 
45
- # Use the style image as the first styled frame.
46
  styled_frames = [style_image]
47
 
48
  # Convert the first frame to grayscale.
49
  prev_gray = cv2.cvtColor(frames[0], cv2.COLOR_BGR2GRAY)
50
 
51
- # Process each subsequent frame.
52
  for i in tqdm(range(1, len(frames)), desc="Propagating style"):
53
  curr_gray = cv2.cvtColor(frames[i], cv2.COLOR_BGR2GRAY)
54
 
55
- # Compute dense optical flow using the Farneback method.
56
  flow = cv2.calcOpticalFlowFarneback(
57
  prev_gray, curr_gray, None,
58
  pyr_scale=0.5, levels=3, winsize=15,
59
  iterations=3, poly_n=5, poly_sigma=1.2, flags=0
60
  )
61
 
62
- # Create a grid of (x,y) coordinates for each pixel.
63
  grid_x, grid_y = np.meshgrid(np.arange(w), np.arange(h))
64
-
65
- # Add the flow vectors to the coordinate grid.
66
  map_x = (grid_x + flow[..., 0]).astype(np.float32)
67
  map_y = (grid_y + flow[..., 1]).astype(np.float32)
68
 
69
- # Warp the last styled frame using the computed mapping.
70
  warped_styled = cv2.remap(styled_frames[-1], map_x, map_y, interpolation=cv2.INTER_LINEAR)
71
  styled_frames.append(warped_styled)
72
 
73
- # Update the previous grayscale frame.
74
  prev_gray = curr_gray
75
 
76
  return styled_frames
77
 
78
  # -----------------------------
79
- # Function to save a list of frames as a video file.
80
  # -----------------------------
81
- def save_video(frames, output_path, fps=30):
82
  """
83
- Combines frames into a video and saves it.
84
  Logic:
85
- - Create a VideoWriter with the specified FPS and frame size.
86
- - Write each frame sequentially to the video file.
87
  """
88
  h, w, _ = frames[0].shape
89
- fourcc = cv2.VideoWriter_fourcc(*'mp4v') # Use 'mp4v' codec for MP4.
90
  out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
91
-
92
  for frame in frames:
93
  out.write(frame)
94
  out.release()
95
 
96
  # -----------------------------
97
- # Main processing function for the Gradio interface.
98
  # -----------------------------
99
  def process_video(video_file, style_image_file, fps=30):
100
  """
101
- Processes the video by propagating the style from the image.
 
 
102
  Inputs:
103
  - video_file: Uploaded video file.
104
  - style_image_file: Uploaded stylized keyframe image.
105
- - fps: Frames per second for the output video.
 
106
  Returns:
107
- - Path to the generated styled video.
108
  """
109
- # For the video file, we expect a file path.
110
  video_path = video_file if isinstance(video_file, str) else video_file["name"]
111
 
112
- # For the style image, Gradio might return a numpy array.
113
  if isinstance(style_image_file, str):
114
  style_image_path = style_image_file
115
  elif isinstance(style_image_file, dict) and "name" in style_image_file:
116
  style_image_path = style_image_file["name"]
117
  elif isinstance(style_image_file, np.ndarray):
118
- # If the image is a numpy array, save it to a temporary file.
119
- tmp_path = os.path.join(tempfile.gettempdir(), "temp_style_image.jpeg")
120
- # Gradio images are usually in RGB; OpenCV uses BGR.
121
- cv2.imwrite(tmp_path, cv2.cvtColor(style_image_file, cv2.COLOR_RGB2BGR))
122
- style_image_path = tmp_path
123
  else:
124
- return "Error: Unsupported style image file format."
125
 
126
- # Extract frames from the input video.
127
  frames = extract_frames(video_path)
128
  if not frames:
129
  return "Error: No frames extracted from the video."
130
 
131
- # Propagate the style image across the frames.
132
  styled_frames = apply_style_propagation(frames, style_image_path)
133
 
134
- # Save the styled frames into a new video file in a temporary directory.
135
  with tempfile.TemporaryDirectory() as tmpdir:
136
- output_video_path = os.path.join(tmpdir, "stylized_video.mp4")
137
- save_video(styled_frames, output_video_path, fps=fps)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  return output_video_path
139
 
140
  # -----------------------------
@@ -150,9 +163,9 @@ iface = gr.Interface(
150
  outputs=gr.Video(label="Styled Video"),
151
  title="Optical Flow Style Propagation",
152
  description=(
153
- "Upload a video and a stylized keyframe image. "
154
- "The style from the keyframe is propagated across the video using optical flow and warping, "
155
- "and the result is output as a new video."
156
  )
157
  )
158
 
 
4
  import tempfile
5
  from tqdm import tqdm
6
  import gradio as gr
7
+ import ffmpeg
8
 
9
  # -----------------------------
10
  # Function to extract frames from the video.
 
def extract_frames(video_path):
    """Decode the video at *video_path* and return every frame as a list.

    Uses cv2.VideoCapture to pull frames one at a time until the stream
    is exhausted. Returns a list of BGR frames (numpy arrays); the list
    is empty if the file cannot be opened or holds no frames.
    """
    capture = cv2.VideoCapture(video_path)
    frames = []
    ok, frame = capture.read()
    while ok:
        frames.append(frame)
        ok, frame = capture.read()  # ok turns False when no frame is left.
    capture.release()
    return frames
 
def apply_style_propagation(frames, style_image_path):
    """Propagate the style of a single keyframe image across all video frames.

    Logic:
    - Load and resize the style image to match the video dimensions.
    - Use the style image as the starting point.
    - For each subsequent frame, compute the dense optical flow between
      the previous and current frame.
    - Warp the previous styled frame so that the style follows the motion.

    Parameters
    ----------
    frames : list of BGR frames (as produced by ``extract_frames``); must
        be non-empty, all frames the same size.
    style_image_path : str, path to the stylized keyframe image on disk.

    Returns
    -------
    list of BGR frames with the same length as ``frames``; element 0 is
    the resized style image, each later element is the previous styled
    frame warped along the optical flow.
    """
    # Load the style image and resize to match frame dimensions.
    style_image = cv2.imread(style_image_path)
    h, w = frames[0].shape[:2]
    style_image = cv2.resize(style_image, (w, h))

    # The first styled frame is the style image itself.
    styled_frames = [style_image]

    prev_gray = cv2.cvtColor(frames[0], cv2.COLOR_BGR2GRAY)

    # FIX: the coordinate grid depends only on the (constant) frame size,
    # so build it once here instead of rebuilding it on every iteration.
    grid_x, grid_y = np.meshgrid(np.arange(w), np.arange(h))

    for i in tqdm(range(1, len(frames)), desc="Propagating style"):
        curr_gray = cv2.cvtColor(frames[i], cv2.COLOR_BGR2GRAY)

        # Dense optical flow (Farneback) from the previous to the current frame.
        # NOTE(review): cv2.remap below performs backward sampling, which
        # strictly wants the current->previous flow; using prev->curr is an
        # approximation — confirm visually if the style drifts over time.
        flow = cv2.calcOpticalFlowFarneback(
            prev_gray, curr_gray, None,
            pyr_scale=0.5, levels=3, winsize=15,
            iterations=3, poly_n=5, poly_sigma=1.2, flags=0
        )

        # Displace every pixel coordinate by its flow vector.
        map_x = (grid_x + flow[..., 0]).astype(np.float32)
        map_y = (grid_y + flow[..., 1]).astype(np.float32)

        # Warp the most recent styled frame so the style follows the motion.
        warped_styled = cv2.remap(styled_frames[-1], map_x, map_y, interpolation=cv2.INTER_LINEAR)
        styled_frames.append(warped_styled)

        prev_gray = curr_gray

    return styled_frames
75
 
76
  # -----------------------------
77
+ # Function to save video frames using OpenCV.
78
  # -----------------------------
79
def save_video_cv2(frames, output_path, fps=30):
    """Write *frames* to *output_path* as an MP4 ('mp4v' codec) at *fps*.

    *frames* must be a non-empty list of same-sized BGR frames.
    """
    height, width, _ = frames[0].shape
    writer = cv2.VideoWriter(
        output_path,
        cv2.VideoWriter_fourcc(*'mp4v'),
        fps,
        (width, height),
    )
    for frame in frames:
        writer.write(frame)
    writer.release()
91
 
92
  # -----------------------------
93
+ # Main processing function for Gradio.
94
  # -----------------------------
95
def process_video(video_file, style_image_file, fps=30):
    """Apply the style image to the video via optical-flow propagation,
    then re-encode with FFmpeg (H.264, yuv420p) for web playback.

    Inputs:
    - video_file: uploaded video (file path string, or dict with a "name" key).
    - style_image_file: uploaded stylized keyframe (path string, dict with a
      "name" key, or an RGB numpy array as delivered by Gradio).
    - fps: output frames per second.

    Returns:
    - Path to the re-encoded, web-playable video, or an error message string.
    """
    # Resolve the video file path.
    video_path = video_file if isinstance(video_file, str) else video_file["name"]

    # Resolve the style image path; a numpy array is written to a temp file.
    if isinstance(style_image_file, str):
        style_image_path = style_image_file
    elif isinstance(style_image_file, dict) and "name" in style_image_file:
        style_image_path = style_image_file["name"]
    elif isinstance(style_image_file, np.ndarray):
        tmp_style_path = os.path.join(tempfile.gettempdir(), "temp_style_image.jpeg")
        # Gradio delivers RGB; OpenCV expects BGR.
        cv2.imwrite(tmp_style_path, cv2.cvtColor(style_image_file, cv2.COLOR_RGB2BGR))
        style_image_path = tmp_style_path
    else:
        return "Error: Unsupported style image format."

    # Extract frames from the video.
    frames = extract_frames(video_path)
    if not frames:
        return "Error: No frames extracted from the video."

    # Propagate the style across video frames.
    styled_frames = apply_style_propagation(frames, style_image_path)

    # BUG FIX: the original wrote the output inside
    # `with tempfile.TemporaryDirectory()`, which deletes the directory (and
    # the video) when the context exits — the returned path pointed at a file
    # that no longer existed by the time Gradio tried to serve it.
    # mkdtemp() persists until explicitly cleaned up.
    tmpdir = tempfile.mkdtemp()

    # Save the raw styled video using OpenCV.
    temp_video_path = os.path.join(tmpdir, "temp_video.mp4")
    save_video_cv2(styled_frames, temp_video_path, fps=fps)

    # Re-encode with FFmpeg for browser compatibility (H.264 + yuv420p).
    output_video_path = os.path.join(tmpdir, "output_video.mp4")
    try:
        (
            ffmpeg
            .input(temp_video_path)
            .output(output_video_path, vcodec='libx264', pix_fmt='yuv420p', r=fps)
            .run(overwrite_output=True, quiet=True)
        )
    except ffmpeg.Error as e:
        print("FFmpeg error:", e)
        return "Error during video re-encoding."

    return output_video_path
152
 
153
  # -----------------------------
 
163
  outputs=gr.Video(label="Styled Video"),
164
  title="Optical Flow Style Propagation",
165
  description=(
166
+ "Upload a video and a stylized keyframe image. The style from the keyframe is propagated "
167
+ "across the video using optical flow and warping. The resulting video is re-encoded to be "
168
+ "web-friendly."
169
  )
170
  )
171