NihalGazi committed on
Commit
0e04a39
·
verified ·
1 Parent(s): 73e6bc5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -32
app.py CHANGED
@@ -9,9 +9,6 @@ import ffmpeg
9
  def extract_frames(video_path):
10
  """
11
  Extracts all frames from the input video.
12
- Logic:
13
- - Open the video file using cv2.VideoCapture.
14
- - Read frames until the video ends.
15
  """
16
  cap = cv2.VideoCapture(video_path)
17
  frames = []
@@ -26,40 +23,102 @@ def extract_frames(video_path):
26
 
27
  def apply_style_propagation(frames, style_image_path):
28
  """
29
- Applies the style from the provided image to each video frame using optical flow.
30
- Logic:
31
- - Load and resize the style image to match the frame dimensions.
32
- - Use the style image as the first styled frame.
33
- - For each subsequent frame, compute dense optical flow between consecutive frames.
34
- - Warp the previously styled frame using the computed flow.
35
- - Clip mapping coordinates to avoid out-of-bound values.
36
  """
 
37
  style_image = cv2.imread(style_image_path)
38
  if style_image is None:
39
  raise ValueError(f"Failed to load style image from {style_image_path}")
40
-
41
  h, w = frames[0].shape[:2]
42
  style_image = cv2.resize(style_image, (w, h))
 
 
43
 
44
  styled_frames = [style_image]
45
  prev_gray = cv2.cvtColor(frames[0], cv2.COLOR_BGR2GRAY)
46
 
 
 
 
 
 
 
47
  for i in tqdm(range(1, len(frames)), desc="Propagating style"):
 
48
  curr_gray = cv2.cvtColor(frames[i], cv2.COLOR_BGR2GRAY)
49
  flow = cv2.calcOpticalFlowFarneback(
50
  prev_gray, curr_gray, None,
51
  pyr_scale=0.5, levels=3, winsize=15,
52
  iterations=3, poly_n=5, poly_sigma=1.2, flags=0
53
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  grid_x, grid_y = np.meshgrid(np.arange(w), np.arange(h))
55
- map_x = grid_x + flow[..., 0]
56
- map_y = grid_y + flow[..., 1]
57
- # Clip mapping coordinates to valid pixel indices.
58
  map_x = np.clip(map_x, 0, w - 1).astype(np.float32)
59
  map_y = np.clip(map_y, 0, h - 1).astype(np.float32)
60
 
 
61
  warped_styled = cv2.remap(styled_frames[-1], map_x, map_y, interpolation=cv2.INTER_LINEAR)
62
- styled_frames.append(warped_styled)
 
 
 
 
 
 
 
 
 
 
 
 
63
  prev_gray = curr_gray
64
 
65
  print(f"Propagated style to {len(styled_frames)} frames.")
@@ -70,8 +129,6 @@ def apply_style_propagation(frames, style_image_path):
70
  def save_video_cv2(frames, output_path, fps=30):
71
  """
72
  Saves a list of frames as a video using OpenCV.
73
- Logic:
74
- - Use cv2.VideoWriter with codec 'mp4v' to create a temporary video file.
75
  """
76
  h, w, _ = frames[0].shape
77
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
@@ -84,21 +141,22 @@ def save_video_cv2(frames, output_path, fps=30):
84
 
85
  def process_video(video_file, style_image_file, fps=30, target_width=0, target_height=0):
86
  """
87
- Processes the input video by applying the style image via optical flow,
88
- optionally downscaling the video and style image to a specified resolution.
 
89
  Then re-encodes the video with FFmpeg for web compatibility.
90
 
91
- Inputs:
92
  - video_file: The input video file.
93
  - style_image_file: The stylized keyframe image.
94
  - fps: Output frames per second.
95
- - target_width: Target width for downscaling (0 means no downscale).
96
- - target_height: Target height for downscaling (0 means no downscale).
97
 
98
  Returns:
99
- - Path to the final, web-playable video.
100
  """
101
- # Determine video file path.
102
  video_path = video_file if isinstance(video_file, str) else video_file["name"]
103
 
104
  # Process the style image input.
@@ -113,7 +171,7 @@ def process_video(video_file, style_image_file, fps=30, target_width=0, target_h
113
  else:
114
  return "Error: Unsupported style image format."
115
 
116
- # Extract frames from the input video.
117
  frames = extract_frames(video_path)
118
  if not frames:
119
  return "Error: No frames extracted from the video."
@@ -121,19 +179,19 @@ def process_video(video_file, style_image_file, fps=30, target_width=0, target_h
121
  original_h, original_w = frames[0].shape[:2]
122
  print(f"Original video resolution: {original_w}x{original_h}")
123
 
124
- # Downscale if target dimensions are provided (non-zero).
125
  if target_width > 0 and target_height > 0:
126
  print(f"Downscaling frames to resolution: {target_width}x{target_height}")
127
  frames = [cv2.resize(frame, (target_width, target_height)) for frame in frames]
128
  else:
129
  print("No downscaling applied. Using original resolution.")
130
 
131
- # Propagate style.
132
  styled_frames = apply_style_propagation(frames, style_image_path)
133
 
134
  # Save intermediate video using OpenCV to a named temporary file.
135
  temp_video_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
136
- temp_video_file.close() # Close so OpenCV can write to this file.
137
  temp_video_path = temp_video_file.name
138
  save_video_cv2(styled_frames, temp_video_path, fps=fps)
139
 
@@ -173,11 +231,14 @@ iface = gr.Interface(
173
  gr.Slider(minimum=0, maximum=1080, step=1, value=0, label="Target Height (0 for original)")
174
  ],
175
  outputs=gr.Video(label="Styled Video"),
176
- title="Optical Flow Style Propagation with Optional Downscaling",
177
  description=(
178
- "Upload a video and a stylized keyframe image. Optionally downscale both to a target resolution "
179
- "by specifying width and height (set both to 0 for original resolution). "
180
- "The style from the keyframe is propagated across the video using optical flow and warping. "
 
 
 
181
  "The output video is re-encoded for web compatibility."
182
  )
183
  )
 
9
  def extract_frames(video_path):
10
  """
11
  Extracts all frames from the input video.
 
 
 
12
  """
13
  cap = cv2.VideoCapture(video_path)
14
  frames = []
 
23
 
24
  def apply_style_propagation(frames, style_image_path):
25
  """
26
+ Applies the style from the provided keyframe image to every frame using optical flow,
27
+ with additional corrections:
28
+ - Median filtering of flow components.
29
+ - Patch-based fallback for blocks with extreme flow.
30
+ - Temporal reset blending with the original style.
31
+ - Sharpening after warping.
 
32
  """
33
+ # Load and resize the style image to match video dimensions.
34
  style_image = cv2.imread(style_image_path)
35
  if style_image is None:
36
  raise ValueError(f"Failed to load style image from {style_image_path}")
 
37
  h, w = frames[0].shape[:2]
38
  style_image = cv2.resize(style_image, (w, h))
39
+ # Keep a copy for temporal re-anchoring.
40
+ original_styled = style_image.copy()
41
 
42
  styled_frames = [style_image]
43
  prev_gray = cv2.cvtColor(frames[0], cv2.COLOR_BGR2GRAY)
44
 
45
+ # Parameters for corrections:
46
+ reset_interval = 30 # Every 30 frames, blend with original style.
47
+ block_size = 16 # Size of block for patch matching.
48
+ patch_threshold = 10 # If mean flow magnitude in a block exceeds this, use patch matching.
49
+ search_margin = 10 # Margin around block for patch matching.
50
+
51
  for i in tqdm(range(1, len(frames)), desc="Propagating style"):
52
+ # Compute optical flow between the previous and current grayscale frames.
53
  curr_gray = cv2.cvtColor(frames[i], cv2.COLOR_BGR2GRAY)
54
  flow = cv2.calcOpticalFlowFarneback(
55
  prev_gray, curr_gray, None,
56
  pyr_scale=0.5, levels=3, winsize=15,
57
  iterations=3, poly_n=5, poly_sigma=1.2, flags=0
58
  )
59
+
60
+ # --- Method 3: Median filtering of the flow components ---
61
+ flow_x = flow[..., 0]
62
+ flow_y = flow[..., 1]
63
+ flow_x_filtered = cv2.medianBlur(flow_x, 3)
64
+ flow_y_filtered = cv2.medianBlur(flow_y, 3)
65
+ flow_filtered = np.dstack((flow_x_filtered, flow_y_filtered))
66
+
67
+ # --- Method 4: Patch-based fallback for extreme flow ---
68
+ flow_corrected = flow_filtered.copy()
69
+ for by in range(0, h, block_size):
70
+ for bx in range(0, w, block_size):
71
+ # Define block region (handle edges)
72
+ y1, y2 = by, min(by + block_size, h)
73
+ x1, x2 = bx, min(bx + block_size, w)
74
+ block_flow = flow_filtered[y1:y2, x1:x2]
75
+ # Compute mean magnitude in the block.
76
+ mag = np.sqrt(block_flow[..., 0]**2 + block_flow[..., 1]**2)
77
+ mean_mag = np.mean(mag)
78
+ if mean_mag > patch_threshold:
79
+ # Use patch matching to recalc flow for this block.
80
+ patch = prev_gray[y1:y2, x1:x2]
81
+ # Define search region in current frame.
82
+ sx1 = max(x1 - search_margin, 0)
83
+ sy1 = max(by - search_margin, 0)
84
+ sx2 = min(x2 + search_margin, w)
85
+ sy2 = min(y2 + search_margin, h)
86
+ search_region = curr_gray[sy1:sy2, sx1:sx2]
87
+ if search_region.shape[0] < patch.shape[0] or search_region.shape[1] < patch.shape[1]:
88
+ continue
89
+ res = cv2.matchTemplate(search_region, patch, cv2.TM_SQDIFF_NORMED)
90
+ _, _, min_loc, _ = cv2.minMaxLoc(res)
91
+ best_x = sx1 + min_loc[0]
92
+ best_y = sy1 + min_loc[1]
93
+ # Calculate offset relative to block's top-left corner.
94
+ offset_x = best_x - x1
95
+ offset_y = best_y - by
96
+ # Override flow for the entire block.
97
+ flow_corrected[y1:y2, x1:x2, 0] = offset_x
98
+ flow_corrected[y1:y2, x1:x2, 1] = offset_y
99
+
100
+ # Compute mapping coordinates.
101
  grid_x, grid_y = np.meshgrid(np.arange(w), np.arange(h))
102
+ map_x = grid_x + flow_corrected[..., 0]
103
+ map_y = grid_y + flow_corrected[..., 1]
 
104
  map_x = np.clip(map_x, 0, w - 1).astype(np.float32)
105
  map_y = np.clip(map_y, 0, h - 1).astype(np.float32)
106
 
107
+ # Warp the previous styled frame using the computed mapping.
108
  warped_styled = cv2.remap(styled_frames[-1], map_x, map_y, interpolation=cv2.INTER_LINEAR)
109
+
110
+ # --- Method 2: Temporal Reset/Re-anchoring ---
111
+ if i % reset_interval == 0:
112
+ # Blend the current warped result with the original styled keyframe.
113
+ warped_styled = cv2.addWeighted(warped_styled, 0.7, original_styled, 0.3, 0)
114
+
115
+ # --- Method 5: Sharpening Post-Warping ---
116
+ kernel = np.array([[0, -1, 0],
117
+ [-1, 5, -1],
118
+ [0, -1, 0]], dtype=np.float32)
119
+ warped_sharpened = cv2.filter2D(warped_styled, -1, kernel)
120
+
121
+ styled_frames.append(warped_sharpened)
122
  prev_gray = curr_gray
123
 
124
  print(f"Propagated style to {len(styled_frames)} frames.")
 
129
  def save_video_cv2(frames, output_path, fps=30):
130
  """
131
  Saves a list of frames as a video using OpenCV.
 
 
132
  """
133
  h, w, _ = frames[0].shape
134
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
 
141
 
142
  def process_video(video_file, style_image_file, fps=30, target_width=0, target_height=0):
143
  """
144
+ Processes the input video by applying the style image via optical flow propagation,
145
+ with additional corrections (methods 2, 3, 4, and 5).
146
+ Optionally downscale the video and style image to the specified resolution.
147
  Then re-encodes the video with FFmpeg for web compatibility.
148
 
149
+ Parameters:
150
  - video_file: The input video file.
151
  - style_image_file: The stylized keyframe image.
152
  - fps: Output frames per second.
153
+ - target_width: Target width for downscaling (0 for original).
154
+ - target_height: Target height for downscaling (0 for original).
155
 
156
  Returns:
157
+ - Path to the final output video.
158
  """
159
+ # Get the video file path.
160
  video_path = video_file if isinstance(video_file, str) else video_file["name"]
161
 
162
  # Process the style image input.
 
171
  else:
172
  return "Error: Unsupported style image format."
173
 
174
+ # Extract frames from the video.
175
  frames = extract_frames(video_path)
176
  if not frames:
177
  return "Error: No frames extracted from the video."
 
179
  original_h, original_w = frames[0].shape[:2]
180
  print(f"Original video resolution: {original_w}x{original_h}")
181
 
182
+ # Downscale if target dimensions are provided.
183
  if target_width > 0 and target_height > 0:
184
  print(f"Downscaling frames to resolution: {target_width}x{target_height}")
185
  frames = [cv2.resize(frame, (target_width, target_height)) for frame in frames]
186
  else:
187
  print("No downscaling applied. Using original resolution.")
188
 
189
+ # Propagate the style using our enhanced method.
190
  styled_frames = apply_style_propagation(frames, style_image_path)
191
 
192
  # Save intermediate video using OpenCV to a named temporary file.
193
  temp_video_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
194
+ temp_video_file.close()
195
  temp_video_path = temp_video_file.name
196
  save_video_cv2(styled_frames, temp_video_path, fps=fps)
197
 
 
231
  gr.Slider(minimum=0, maximum=1080, step=1, value=0, label="Target Height (0 for original)")
232
  ],
233
  outputs=gr.Video(label="Styled Video"),
234
+ title="Optical Flow Style Propagation with Corrections",
235
  description=(
236
+ "Upload a video and a stylized keyframe image. Optionally downscale both to a target resolution. "
237
+ "The style is propagated using optical flow with additional corrections:\n"
238
+ " Temporal re-anchoring\n"
239
+ "• Median filtering of the flow\n"
240
+ "• Patch-based flow correction\n"
241
+ "• Post-warp sharpening\n"
242
  "The output video is re-encoded for web compatibility."
243
  )
244
  )