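# NOTE(review): the six lines below appear to be page text captured together with
# the source; they are kept verbatim as comments so the module stays valid Python.
# SpyC0der77's picture
# Update app.py
# 89bc003 verified
# raw
# history blame
# 9.24 kB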
import cv2
import numpy as np
import csv
import math
import torch
import tempfile
import os
import gradio as gr
# Pick the torch device once at import time: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Attempt to fetch RAFT through torch.hub. Any failure (for example a change in
# the repository layout) leaves raft_model set to None, which the motion-CSV
# code uses as the signal to fall back to OpenCV optical flow.
try:
    # trust_repo=True avoids the interactive confirmation prompt.
    raft_model = torch.hub.load(
        "princeton-vl/RAFT",
        "raft_small",
        pretrained=True,
        trust_repo=True,
    )
    raft_model = raft_model.to(device)
    raft_model.eval()
    print("RAFT model loaded successfully.")
except Exception as load_error:
    print("Error loading RAFT model:", load_error)
    print("Falling back to OpenCV optical flow for motion CSV generation.")
    raft_model = None
def _frame_to_tensor(frame):
    """Convert a BGR frame to a 1x3xHxW float tensor scaled by 1/255 on ``device``."""
    # NOTE(review): confirm that a 0-1 input range is what the loaded RAFT model
    # expects; the scaling is kept exactly as the original code had it.
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    tensor = torch.from_numpy(rgb).permute(2, 0, 1).float().unsqueeze(0) / 255.0
    return tensor.to(device)


def _center_offsets(height, width):
    """Return per-pixel (x, y) offsets from the image centre for a height x width grid."""
    x_coords, y_coords = np.meshgrid(np.arange(width), np.arange(height))
    return x_coords - width / 2, y_coords - height / 2


def generate_motion_csv(video_file, output_csv=None):
    """
    Generates a CSV file with motion data (columns: frame, mag, ang, zoom) from an input video.

    Optical flow is computed between each pair of consecutive frames, using the
    RAFT model when ``raft_model`` is available and OpenCV's Farneback optical
    flow otherwise. One row is written per frame pair:

    * ``frame``: 1 for the flow between the first and second frame read, 2 for
      the next pair, and so on.
    * ``mag`` / ``ang``: length and direction (degrees, in [0, 360)) of the
      per-component median flow vector.
    * ``zoom``: fraction of pixels whose flow points away from the image centre.

    Args:
        video_file (str): Path to the input video.
        output_csv (str): Optional output CSV file path. If None, a temporary file is created.

    Returns:
        output_csv (str): Path to the generated CSV file.

    Raises:
        ValueError: If the video cannot be opened or its first frame cannot be read.
    """
    cap = cv2.VideoCapture(video_file)
    try:
        if not cap.isOpened():
            raise ValueError("Could not open video file for CSV generation.")
        ret, first_frame = cap.read()
        if not ret:
            raise ValueError("Cannot read first frame from video.")

        # Create the temporary file only once the video is known to be readable,
        # so a bad input does not leave an orphaned file behind.
        if output_csv is None:
            with tempfile.NamedTemporaryFile(delete=False, suffix='.csv') as temp_file:
                output_csv = temp_file.name

        use_raft = raft_model is not None
        if use_raft:
            prev_tensor = _frame_to_tensor(first_frame)
        else:
            prev_gray = cv2.cvtColor(first_frame, cv2.COLOR_BGR2GRAY)

        # Centre-offset grids depend only on the flow size, so they are cached
        # and rebuilt only if that size ever changes.
        offsets = None

        with open(output_csv, 'w', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=['frame', 'mag', 'ang', 'zoom'])
            writer.writeheader()

            frame_idx = 1
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                if use_raft:
                    curr_tensor = _frame_to_tensor(frame)
                    with torch.no_grad():
                        _, flow_up = raft_model(prev_tensor, curr_tensor, iters=20, test_mode=True)
                    flow = flow_up[0].permute(1, 2, 0).cpu().numpy()
                    prev_tensor = curr_tensor
                else:
                    curr_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                    flow = cv2.calcOpticalFlowFarneback(
                        prev_gray, curr_gray, None,
                        pyr_scale=0.5, levels=3, winsize=15,
                        iterations=3, poly_n=5, poly_sigma=1.2, flags=0,
                    )
                    prev_gray = curr_gray

                # Take the median of the x and y components and only then convert
                # to polar form. A median over per-pixel angles is unreliable
                # because angles wrap at 0/360 degrees: flow pointing along +x
                # with a little noise has angles near both 0 and 360, whose
                # median can land near 180.
                median_dx = float(np.median(flow[..., 0]))
                median_dy = float(np.median(flow[..., 1]))
                median_mag = math.hypot(median_dx, median_dy)
                median_ang = math.degrees(math.atan2(median_dy, median_dx)) % 360.0

                # "Zoom factor": fraction of pixels moving away from the centre,
                # i.e. whose flow has a positive dot product with the offset
                # from the centre.
                height, width = flow.shape[:2]
                if offsets is None or offsets[0].shape != (height, width):
                    offsets = _center_offsets(height, width)
                x_offset, y_offset = offsets
                radial = flow[..., 0] * x_offset + flow[..., 1] * y_offset
                zoom_factor = np.count_nonzero(radial > 0) / (width * height)

                writer.writerow({
                    'frame': frame_idx,
                    'mag': median_mag,
                    'ang': median_ang,
                    'zoom': zoom_factor,
                })
                frame_idx += 1
    finally:
        # Release the capture on error paths as well.
        cap.release()

    print(f"Motion CSV generated: {output_csv}")
    return output_csv
def read_motion_csv(csv_filename):
    """
    Load a motion CSV (columns: frame, mag, ang, zoom) and build stabilization offsets.

    Each row's magnitude and angle (in degrees) are converted to an x/y
    displacement. Displacements are summed in file order, and the negated
    running sum is stored under that row's frame number.

    Returns:
        A dictionary mapping frame numbers to (dx, dy) offsets.
    """
    offsets_by_frame = {}
    total_x = 0.0
    total_y = 0.0
    with open(csv_filename, 'r') as handle:
        for record in csv.DictReader(handle):
            frame_key = int(record['frame'])
            magnitude = float(record['mag'])
            theta = math.radians(float(record['ang']))
            total_x += magnitude * math.cos(theta)
            total_y += magnitude * math.sin(theta)
            # Negate: the offset undoes the motion accumulated so far.
            offsets_by_frame[frame_key] = (-total_x, -total_y)
    return offsets_by_frame
def stabilize_video_using_csv(video_file, csv_file, zoom=1.0, output_file=None):
    """
    Stabilizes the input video using motion data from the CSV file.

    Every frame is optionally enlarged by ``zoom`` and centre-cropped back to
    the source size, then translated by the cumulative offset that
    ``read_motion_csv`` returns for that frame.

    Frame numbering follows ``generate_motion_csv``: CSV row ``k`` holds the
    motion between frame ``k - 1`` and frame ``k``, counting the first frame as
    0. The first frame therefore has no entry and is left untranslated.

    Args:
        video_file (str): Path to the input video.
        csv_file (str): Path to the motion CSV file.
        zoom (float): Zoom factor to apply before stabilization (default: 1.0).
        output_file (str): Path for the output stabilized video. If None, a temporary file is created.

    Returns:
        output_file (str): Path to the stabilized video file.

    Raises:
        ValueError: If the input video cannot be opened or the output video
            cannot be opened for writing.
    """
    motion_data = read_motion_csv(csv_file)

    cap = cv2.VideoCapture(video_file)
    out = None
    try:
        if not cap.isOpened():
            raise ValueError("Could not open video file for stabilization.")

        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        if output_file is None:
            with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_file:
                output_file = temp_file.name

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_file, fourcc, fps, (width, height))
        if not out.isOpened():
            # Fail loudly instead of silently producing an empty output file.
            raise ValueError("Could not open output video file for writing.")

        # Start at 0 so that frame k picks up CSV row k (motion accumulated up
        # to that frame). Starting at 1 gave each frame the next frame's
        # correction and left the last frame with no correction at all.
        frame_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            if zoom != 1.0:
                zoomed_frame = cv2.resize(frame, None, fx=zoom, fy=zoom, interpolation=cv2.INTER_LINEAR)
                zoomed_h, zoomed_w = zoomed_frame.shape[:2]
                start_x = max((zoomed_w - width) // 2, 0)
                start_y = max((zoomed_h - height) // 2, 0)
                frame = zoomed_frame[start_y:start_y + height, start_x:start_x + width]

            dx, dy = motion_data.get(frame_idx, (0.0, 0.0))
            # Offsets were measured on the unscaled video; enlarging the frame
            # enlarges the apparent motion by the same factor.
            transform = np.array([[1, 0, dx * zoom],
                                  [0, 1, dy * zoom]], dtype=np.float32)
            stabilized_frame = cv2.warpAffine(frame, transform, (width, height))
            out.write(stabilized_frame)
            frame_idx += 1
    finally:
        # Release both handles on error paths as well.
        cap.release()
        if out is not None:
            out.release()

    print(f"Stabilized video saved to: {output_file}")
    return output_file
def process_video_ai(video_file, zoom):
    """
    Gradio interface function:
    - Generates motion data (CSV) from the input video using an AI model (RAFT, if available).
    - Stabilizes the video based on the generated motion data.
    - Removes the intermediate CSV afterwards.

    Args:
        video_file: Path to the uploaded video, or a dict holding the path under "name".
        zoom (float): Zoom factor forwarded to the stabilization step.

    Returns:
        Tuple containing the original video file path and the stabilized video file path.

    Raises:
        ValueError: If no video file was provided.
    """
    if isinstance(video_file, dict):
        video_file = video_file.get("name")
    # Treat an empty path the same as a missing upload.
    if not video_file:
        raise ValueError("Please upload a video file.")

    # Generate motion CSV using the AI model (or fallback) for optical flow.
    csv_file = generate_motion_csv(video_file)
    try:
        # Stabilize the video using the generated CSV.
        stabilized_path = stabilize_video_using_csv(video_file, csv_file, zoom=zoom)
    finally:
        # The CSV is a temporary intermediate that is never returned to the
        # caller, so delete it instead of leaking one file per request.
        # Cleanup is best-effort: a failed delete must not mask the real outcome.
        try:
            os.remove(csv_file)
        except OSError:
            pass
    return video_file, stabilized_path
# Build the Gradio UI: the first column holds the inputs and the trigger button,
# the second column shows the original and stabilized videos.
with gr.Blocks() as demo:
    gr.Markdown("# AI-Powered Video Stabilization")
    gr.Markdown(
        "Upload a video and select a zoom factor. The system will automatically "
        "generate motion data (video.flow.csv) using an AI model (RAFT, if available) "
        "and then stabilize the video."
    )

    with gr.Row():
        with gr.Column():
            video_input = gr.Video(label="Input Video")
            zoom_slider = gr.Slider(
                minimum=1.0,
                maximum=2.0,
                step=0.1,
                value=1.0,
                label="Zoom Factor",
            )
            process_button = gr.Button("Process Video")
        with gr.Column():
            original_video = gr.Video(label="Original Video")
            stabilized_video = gr.Video(label="Stabilized Video")

    # Clicking the button runs the full pipeline and fills both output players.
    process_button.click(
        fn=process_video_ai,
        inputs=[video_input, zoom_slider],
        outputs=[original_video, stabilized_video],
    )

# Start the app as soon as the module is executed.
demo.launch()