# EBSynth / app.py
import cv2
import numpy as np
import os
import tempfile
from tqdm import tqdm
import gradio as gr
import ffmpeg  # provided by the ffmpeg-python package (pip install ffmpeg-python)

# -----------------------------
# Function to extract frames from the video.
# -----------------------------
def extract_frames(video_path):
    """
    Opens the video file and extracts all frames into a list.

    Logic:
    - Use cv2.VideoCapture to read the video frame-by-frame.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    while True:
        ret, frame = cap.read()  # Read a frame.
        if not ret:
            break  # Stop when no more frames are available.
        frames.append(frame)
    cap.release()
    return frames
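
# A minimal usage sketch (assumes a hypothetical local file "input.mp4"):
#
#   frames = extract_frames("input.mp4")
#   print(f"Extracted {len(frames)} frames; first frame shape: {frames[0].shape}")
#
# Each frame is a BGR numpy array of shape (height, width, 3), as OpenCV returns it.
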
# -----------------------------
# Function to propagate the style image using optical flow.
# -----------------------------
def apply_style_propagation(frames, style_image_path):
    """
    Applies the style from the provided image onto each video frame.

    Logic:
    - Load and resize the style image to match the video dimensions.
    - Use the style image as the first styled frame.
    - For each subsequent frame, compute the dense optical flow from the
      current frame back to the previous frame.
    - Warp the previous styled frame with that flow so the style follows
      the motion.
    """
    # Load the style image and resize it to match the frame dimensions.
    style_image = cv2.imread(style_image_path)
    h, w = frames[0].shape[:2]
    style_image = cv2.resize(style_image, (w, h))
    # The first styled frame is the style image itself.
    styled_frames = [style_image]
    # Convert the first frame to grayscale.
    prev_gray = cv2.cvtColor(frames[0], cv2.COLOR_BGR2GRAY)
    # Build the sampling grid once; it is reused for every frame.
    grid_x, grid_y = np.meshgrid(np.arange(w), np.arange(h))
    # Process subsequent frames.
    for i in tqdm(range(1, len(frames)), desc="Propagating style"):
        curr_gray = cv2.cvtColor(frames[i], cv2.COLOR_BGR2GRAY)
        # Compute dense optical flow using Farneback's method. The flow is
        # computed from the current frame to the previous frame so it can be
        # used directly for backward warping with cv2.remap: each output
        # pixel samples the previous styled frame at (x + flow_x, y + flow_y).
        flow = cv2.calcOpticalFlowFarneback(
            curr_gray, prev_gray, None,
            pyr_scale=0.5, levels=3, winsize=15,
            iterations=3, poly_n=5, poly_sigma=1.2, flags=0
        )
        map_x = (grid_x + flow[..., 0]).astype(np.float32)
        map_y = (grid_y + flow[..., 1]).astype(np.float32)
        # Warp the last styled frame.
        warped_styled = cv2.remap(
            styled_frames[-1], map_x, map_y, interpolation=cv2.INTER_LINEAR
        )
        styled_frames.append(warped_styled)
        # The current frame becomes the previous frame for the next step.
        prev_gray = curr_gray
    return styled_frames
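
# A minimal usage sketch (the file names "input.mp4" and "style.jpeg" are
# hypothetical):
#
#   frames = extract_frames("input.mp4")
#   styled = apply_style_propagation(frames, "style.jpeg")
#   cv2.imwrite("styled_frame_010.png", styled[10])
#
# Because each frame is warped from the previous one, small flow errors
# accumulate, so the style drifts the further a frame is from the keyframe.
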
# -----------------------------
# Function to save video frames using OpenCV.
# -----------------------------
def save_video_cv2(frames, output_path, fps=30):
    """
    Saves a list of frames as a video file.

    Logic:
    - Uses cv2.VideoWriter with the 'mp4v' codec to write the frames
      to output_path.
    """
    h, w, _ = frames[0].shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
    for frame in frames:
        out.write(frame)
    out.release()
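
# A minimal usage sketch (assumes `styled` is a non-empty list of same-sized
# BGR frames, e.g. from apply_style_propagation above):
#
#   save_video_cv2(styled, "raw_output.mp4", fps=24)
#
# 'mp4v' output is not reliably playable in browsers, which is why
# process_video() re-encodes the result with H.264 below.
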
# -----------------------------
# Main processing function for Gradio.
# -----------------------------
def process_video(video_file, style_image_file, fps=30):
    """
    Processes the video by applying the style image via optical flow,
    then re-encodes the result with FFmpeg (H.264, yuv420p) for web playback.

    Inputs:
    - video_file: Uploaded video file.
    - style_image_file: Uploaded stylized keyframe image.
    - fps: Output frames per second.

    Returns:
    - Path to the re-encoded, web-playable video, or an error message.
    """
    # Get the video file path (Gradio may pass a path string or a dict).
    video_path = video_file if isinstance(video_file, str) else video_file["name"]

    # Process the style image input: if it is a numpy array, save it to a
    # temporary file so it can be read back with cv2.imread.
    if isinstance(style_image_file, str):
        style_image_path = style_image_file
    elif isinstance(style_image_file, dict) and "name" in style_image_file:
        style_image_path = style_image_file["name"]
    elif isinstance(style_image_file, np.ndarray):
        tmp_style_path = os.path.join(tempfile.gettempdir(), "temp_style_image.jpeg")
        # Convert from RGB (Gradio) to BGR (OpenCV) before writing.
        cv2.imwrite(tmp_style_path, cv2.cvtColor(style_image_file, cv2.COLOR_RGB2BGR))
        style_image_path = tmp_style_path
    else:
        return "Error: Unsupported style image format."

    # Extract frames from the video.
    frames = extract_frames(video_path)
    if not frames:
        return "Error: No frames extracted from the video."

    # Propagate the style across the video frames.
    styled_frames = apply_style_propagation(frames, style_image_path)

    # The final output must outlive the temporary working directory below,
    # so it gets its own persistent temporary file.
    out_fd, output_video_path = tempfile.mkstemp(suffix=".mp4")
    os.close(out_fd)

    # Use a temporary directory for the intermediate file.
    with tempfile.TemporaryDirectory() as tmpdir:
        # Save the raw styled video using OpenCV.
        temp_video_path = os.path.join(tmpdir, "temp_video.mp4")
        save_video_cv2(styled_frames, temp_video_path, fps=fps)

        # Re-encode with FFmpeg to produce a web-friendly video.
        try:
            (
                ffmpeg
                .input(temp_video_path)
                .output(output_video_path, vcodec='libx264', pix_fmt='yuv420p', r=fps)
                .run(overwrite_output=True, quiet=True)
            )
        except ffmpeg.Error as e:
            print("FFmpeg error:", e.stderr.decode() if e.stderr else e)
            return "Error during video re-encoding."

    return output_video_path
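
# A minimal usage sketch (paths are hypothetical):
#
#   result = process_video("input.mp4", "style.jpeg", fps=30)
#   print(result)  # Either the path to the re-encoded .mp4 or an error string.
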
# -----------------------------
# Define the Gradio Interface.
# -----------------------------
iface = gr.Interface(
    fn=process_video,
    inputs=[
        gr.Video(label="Input Video (v.mp4)"),
        gr.Image(label="Stylized Keyframe (a.jpeg)"),
        gr.Slider(minimum=1, maximum=60, step=1, value=30, label="Output FPS")
    ],
    outputs=gr.Video(label="Styled Video"),
    title="Optical Flow Style Propagation",
    description=(
        "Upload a video and a stylized keyframe image. The style from the keyframe is propagated "
        "across the video using optical flow and warping. The resulting video is re-encoded to be "
        "web-friendly."
    )
)

# -----------------------------
# Launch the Gradio App with public sharing enabled.
# -----------------------------
if __name__ == "__main__":
    iface.launch(share=True)