Update app.py
Browse files
app.py
CHANGED
@@ -4,6 +4,7 @@ import os
|
|
4 |
import tempfile
|
5 |
from tqdm import tqdm
|
6 |
import gradio as gr
|
|
|
7 |
|
8 |
# -----------------------------
|
9 |
# Function to extract frames from the video.
|
@@ -12,15 +13,14 @@ def extract_frames(video_path):
|
|
12 |
"""
|
13 |
Opens the video file and extracts all frames into a list.
|
14 |
Logic:
|
15 |
-
-
|
16 |
-
- Read frames until no frame is returned.
|
17 |
"""
|
18 |
cap = cv2.VideoCapture(video_path)
|
19 |
frames = []
|
20 |
while True:
|
21 |
-
ret, frame = cap.read() #
|
22 |
if not ret:
|
23 |
-
break #
|
24 |
frames.append(frame)
|
25 |
cap.release()
|
26 |
return frames
|
@@ -32,109 +32,122 @@ def apply_style_propagation(frames, style_image_path):
|
|
32 |
"""
|
33 |
Applies the style from the provided image onto each video frame.
|
34 |
Logic:
|
35 |
-
- Load and resize the style image to match video dimensions.
|
36 |
-
- Use the style image as the
|
37 |
-
- For each subsequent frame, compute optical flow between
|
38 |
-
- Warp the previous styled frame
|
39 |
"""
|
40 |
-
# Load and resize
|
41 |
style_image = cv2.imread(style_image_path)
|
42 |
h, w = frames[0].shape[:2]
|
43 |
style_image = cv2.resize(style_image, (w, h))
|
44 |
|
45 |
-
#
|
46 |
styled_frames = [style_image]
|
47 |
|
48 |
# Convert the first frame to grayscale.
|
49 |
prev_gray = cv2.cvtColor(frames[0], cv2.COLOR_BGR2GRAY)
|
50 |
|
51 |
-
# Process
|
52 |
for i in tqdm(range(1, len(frames)), desc="Propagating style"):
|
53 |
curr_gray = cv2.cvtColor(frames[i], cv2.COLOR_BGR2GRAY)
|
54 |
|
55 |
-
# Compute dense optical flow using
|
56 |
flow = cv2.calcOpticalFlowFarneback(
|
57 |
prev_gray, curr_gray, None,
|
58 |
pyr_scale=0.5, levels=3, winsize=15,
|
59 |
iterations=3, poly_n=5, poly_sigma=1.2, flags=0
|
60 |
)
|
61 |
|
62 |
-
# Create a grid
|
63 |
grid_x, grid_y = np.meshgrid(np.arange(w), np.arange(h))
|
64 |
-
|
65 |
-
# Add the flow vectors to the coordinate grid.
|
66 |
map_x = (grid_x + flow[..., 0]).astype(np.float32)
|
67 |
map_y = (grid_y + flow[..., 1]).astype(np.float32)
|
68 |
|
69 |
-
# Warp the last styled frame
|
70 |
warped_styled = cv2.remap(styled_frames[-1], map_x, map_y, interpolation=cv2.INTER_LINEAR)
|
71 |
styled_frames.append(warped_styled)
|
72 |
|
73 |
-
# Update
|
74 |
prev_gray = curr_gray
|
75 |
|
76 |
return styled_frames
|
77 |
|
78 |
# -----------------------------
|
79 |
-
# Function to save
|
80 |
# -----------------------------
|
81 |
-
def
|
82 |
"""
|
83 |
-
|
84 |
Logic:
|
85 |
-
-
|
86 |
-
- Write each frame sequentially to the video file.
|
87 |
"""
|
88 |
h, w, _ = frames[0].shape
|
89 |
-
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
90 |
out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
|
91 |
-
|
92 |
for frame in frames:
|
93 |
out.write(frame)
|
94 |
out.release()
|
95 |
|
96 |
# -----------------------------
|
97 |
-
# Main processing function for
|
98 |
# -----------------------------
|
99 |
def process_video(video_file, style_image_file, fps=30):
|
100 |
"""
|
101 |
-
Processes the video by
|
|
|
|
|
102 |
Inputs:
|
103 |
- video_file: Uploaded video file.
|
104 |
- style_image_file: Uploaded stylized keyframe image.
|
105 |
-
- fps:
|
|
|
106 |
Returns:
|
107 |
-
- Path to the
|
108 |
"""
|
109 |
-
#
|
110 |
video_path = video_file if isinstance(video_file, str) else video_file["name"]
|
111 |
|
112 |
-
#
|
113 |
if isinstance(style_image_file, str):
|
114 |
style_image_path = style_image_file
|
115 |
elif isinstance(style_image_file, dict) and "name" in style_image_file:
|
116 |
style_image_path = style_image_file["name"]
|
117 |
elif isinstance(style_image_file, np.ndarray):
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
style_image_path = tmp_path
|
123 |
else:
|
124 |
-
return "Error: Unsupported style image
|
125 |
|
126 |
-
# Extract frames from the
|
127 |
frames = extract_frames(video_path)
|
128 |
if not frames:
|
129 |
return "Error: No frames extracted from the video."
|
130 |
|
131 |
-
# Propagate the style
|
132 |
styled_frames = apply_style_propagation(frames, style_image_path)
|
133 |
|
134 |
-
#
|
135 |
with tempfile.TemporaryDirectory() as tmpdir:
|
136 |
-
|
137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
return output_video_path
|
139 |
|
140 |
# -----------------------------
|
@@ -150,9 +163,9 @@ iface = gr.Interface(
|
|
150 |
outputs=gr.Video(label="Styled Video"),
|
151 |
title="Optical Flow Style Propagation",
|
152 |
description=(
|
153 |
-
"Upload a video and a stylized keyframe image. "
|
154 |
-
"
|
155 |
-
"
|
156 |
)
|
157 |
)
|
158 |
|
|
|
4 |
import tempfile
|
5 |
from tqdm import tqdm
|
6 |
import gradio as gr
|
7 |
+
import ffmpeg
|
8 |
|
9 |
# -----------------------------
|
10 |
# Function to extract frames from the video.
|
|
|
def extract_frames(video_path):
    """
    Return every frame of the video at *video_path* as a list of BGR arrays.

    Reads the file with cv2.VideoCapture frame-by-frame until the stream
    is exhausted, then releases the capture handle.
    """
    capture = cv2.VideoCapture(video_path)
    collected = []
    # capture.read() yields (ok, frame); ok goes False once no frame remains.
    while (result := capture.read())[0]:
        collected.append(result[1])
    capture.release()
    return collected
|
|
def apply_style_propagation(frames, style_image_path):
    """
    Propagate the style of a single keyframe image across all video frames.

    Logic:
    - Load and resize the style image to match the video dimensions.
    - Use the style image as the first styled frame.
    - For each subsequent frame, compute dense optical flow and warp the
      previous styled frame so the style follows the motion.

    Parameters:
    - frames: non-empty list of same-sized BGR frames.
    - style_image_path: path to the stylized keyframe image.

    Returns:
    - List of styled BGR frames, one per input frame.
    """
    # Load the style image and resize it to the frame dimensions.
    style_image = cv2.imread(style_image_path)
    h, w = frames[0].shape[:2]
    style_image = cv2.resize(style_image, (w, h))

    # The first styled frame is the style image itself.
    styled_frames = [style_image]

    # Convert the first frame to grayscale.
    prev_gray = cv2.cvtColor(frames[0], cv2.COLOR_BGR2GRAY)

    # The sampling grid is loop-invariant; build it once instead of per frame.
    grid_x, grid_y = np.meshgrid(np.arange(w), np.arange(h))

    # Process subsequent frames.
    for i in tqdm(range(1, len(frames)), desc="Propagating style"):
        curr_gray = cv2.cvtColor(frames[i], cv2.COLOR_BGR2GRAY)

        # BUGFIX: cv2.remap samples the *source* image at map(x, y), so the
        # map must carry the backward flow (current -> previous).  The
        # original computed forward flow (previous -> current), which drags
        # the style in the wrong direction.  Swapping the image order gives
        # the backward field directly.
        flow = cv2.calcOpticalFlowFarneback(
            curr_gray, prev_gray, None,
            pyr_scale=0.5, levels=3, winsize=15,
            iterations=3, poly_n=5, poly_sigma=1.2, flags=0
        )

        # Add the flow vectors to the coordinate grid.
        map_x = (grid_x + flow[..., 0]).astype(np.float32)
        map_y = (grid_y + flow[..., 1]).astype(np.float32)

        # Warp the last styled frame onto the current frame's geometry.
        warped_styled = cv2.remap(styled_frames[-1], map_x, map_y,
                                  interpolation=cv2.INTER_LINEAR)
        styled_frames.append(warped_styled)

        # Current frame becomes the reference for the next iteration.
        prev_gray = curr_gray

    return styled_frames
75 |
|
76 |
# -----------------------------
|
77 |
+
# Function to save video frames using OpenCV.
|
78 |
# -----------------------------
|
79 |
+
def save_video_cv2(frames, output_path, fps=30):
|
80 |
"""
|
81 |
+
Saves a list of frames as a video file.
|
82 |
Logic:
|
83 |
+
- Uses cv2.VideoWriter with codec 'mp4v' to create a temporary video file.
|
|
|
84 |
"""
|
85 |
h, w, _ = frames[0].shape
|
86 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
87 |
out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
|
|
|
88 |
for frame in frames:
|
89 |
out.write(frame)
|
90 |
out.release()
|
91 |
|
92 |
# -----------------------------
|
93 |
+
# Main processing function for Gradio.
|
94 |
# -----------------------------
|
95 |
def process_video(video_file, style_image_file, fps=30):
|
96 |
"""
|
97 |
+
Processes the video by applying the style image via optical flow.
|
98 |
+
Then re-encodes the video using FFmpeg (H.264, yuv420p) for web compatibility.
|
99 |
+
|
100 |
Inputs:
|
101 |
- video_file: Uploaded video file.
|
102 |
- style_image_file: Uploaded stylized keyframe image.
|
103 |
+
- fps: Output frames per second.
|
104 |
+
|
105 |
Returns:
|
106 |
+
- Path to the re-encoded, web-playable video.
|
107 |
"""
|
108 |
+
# Get the video file path.
|
109 |
video_path = video_file if isinstance(video_file, str) else video_file["name"]
|
110 |
|
111 |
+
# Process style image input: if it's a numpy array, save it to a temporary file.
|
112 |
if isinstance(style_image_file, str):
|
113 |
style_image_path = style_image_file
|
114 |
elif isinstance(style_image_file, dict) and "name" in style_image_file:
|
115 |
style_image_path = style_image_file["name"]
|
116 |
elif isinstance(style_image_file, np.ndarray):
|
117 |
+
tmp_style_path = os.path.join(tempfile.gettempdir(), "temp_style_image.jpeg")
|
118 |
+
# Convert from RGB (Gradio) to BGR (OpenCV)
|
119 |
+
cv2.imwrite(tmp_style_path, cv2.cvtColor(style_image_file, cv2.COLOR_RGB2BGR))
|
120 |
+
style_image_path = tmp_style_path
|
|
|
121 |
else:
|
122 |
+
return "Error: Unsupported style image format."
|
123 |
|
124 |
+
# Extract frames from the video.
|
125 |
frames = extract_frames(video_path)
|
126 |
if not frames:
|
127 |
return "Error: No frames extracted from the video."
|
128 |
|
129 |
+
# Propagate the style across video frames.
|
130 |
styled_frames = apply_style_propagation(frames, style_image_path)
|
131 |
|
132 |
+
# Use a temporary directory for processing.
|
133 |
with tempfile.TemporaryDirectory() as tmpdir:
|
134 |
+
# Save the raw styled video using OpenCV.
|
135 |
+
temp_video_path = os.path.join(tmpdir, "temp_video.mp4")
|
136 |
+
save_video_cv2(styled_frames, temp_video_path, fps=fps)
|
137 |
+
|
138 |
+
# Re-encode using FFmpeg to produce a web-friendly video.
|
139 |
+
output_video_path = os.path.join(tmpdir, "output_video.mp4")
|
140 |
+
try:
|
141 |
+
(
|
142 |
+
ffmpeg
|
143 |
+
.input(temp_video_path)
|
144 |
+
.output(output_video_path, vcodec='libx264', pix_fmt='yuv420p', r=fps)
|
145 |
+
.run(overwrite_output=True, quiet=True)
|
146 |
+
)
|
147 |
+
except ffmpeg.Error as e:
|
148 |
+
print("FFmpeg error:", e)
|
149 |
+
return "Error during video re-encoding."
|
150 |
+
|
151 |
return output_video_path
|
152 |
|
153 |
# -----------------------------
|
|
|
163 |
outputs=gr.Video(label="Styled Video"),
|
164 |
title="Optical Flow Style Propagation",
|
165 |
description=(
|
166 |
+
"Upload a video and a stylized keyframe image. The style from the keyframe is propagated "
|
167 |
+
"across the video using optical flow and warping. The resulting video is re-encoded to be "
|
168 |
+
"web-friendly."
|
169 |
)
|
170 |
)
|
171 |
|