import cv2
import numpy as np
import csv
import math
import torch
import tempfile
import os
import gradio as gr
import time
import io
from contextlib import redirect_stdout
# Set up device for torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using device: {device}")
# Try to load the RAFT model from torch.hub.
# If it fails, we fall back to OpenCV optical flow.
try:
    print("[INFO] Attempting to load RAFT model from torch.hub...")
    raft_model = torch.hub.load("princeton-vl/RAFT", "raft_small", pretrained=True, trust_repo=True)
    raft_model = raft_model.to(device)
    raft_model.eval()
    print("[INFO] RAFT model loaded successfully.")
except Exception as e:
    print("[ERROR] Error loading RAFT model:", e)
    print("[INFO] Falling back to OpenCV Farneback optical flow.")
    raft_model = None
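# Note: torch.hub.load requires the target repo to expose a hubconf.py; if the
# RAFT repo does not, the except branch above fires and the app runs on the
# Farneback fallback instead.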
def generate_motion_csv(video_file, output_csv=None):
    """
    Generates a CSV file with motion data (columns: frame, mag, ang, zoom) from an input video.
    Uses RAFT if available, otherwise falls back to OpenCV's Farneback optical flow.
    """
    start_time = time.time()
    if output_csv is None:
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.csv')
        output_csv = temp_file.name
        temp_file.close()
    cap = cv2.VideoCapture(video_file)
    if not cap.isOpened():
        raise ValueError("[ERROR] Could not open video file for CSV generation.")
    print(f"[INFO] Generating motion CSV for video: {video_file}")
    with open(output_csv, 'w', newline='') as csvfile:
        fieldnames = ['frame', 'mag', 'ang', 'zoom']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        ret, first_frame = cap.read()
        if not ret:
            raise ValueError("[ERROR] Cannot read first frame from video.")
        if raft_model is not None:
            first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
            # The princeton-vl RAFT implementation rescales inputs from [0, 255]
            # to [-1, 1] internally, so frames are passed as raw 0-255 floats
            # here rather than divided by 255.
            prev_tensor = torch.from_numpy(first_frame_rgb).permute(2, 0, 1).float().unsqueeze(0)
            prev_tensor = prev_tensor.to(device)
            print("[INFO] Using RAFT model for optical flow computation.")
        else:
            prev_gray = cv2.cvtColor(first_frame, cv2.COLOR_BGR2GRAY)
            print("[INFO] Using OpenCV Farneback optical flow for computation.")
        frame_idx = 1
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"[INFO] Total frames to process: {total_frames}")
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if raft_model is not None:
                curr_frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                # As above, pass raw 0-255 floats; RAFT normalizes internally.
                curr_tensor = torch.from_numpy(curr_frame_rgb).permute(2, 0, 1).float().unsqueeze(0)
                curr_tensor = curr_tensor.to(device)
                with torch.no_grad():
                    flow_low, flow_up = raft_model(prev_tensor, curr_tensor, iters=20, test_mode=True)
                flow = flow_up[0].permute(1, 2, 0).cpu().numpy()
                prev_tensor = curr_tensor.clone()
            else:
                curr_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                flow = cv2.calcOpticalFlowFarneback(prev_gray, curr_gray, None,
                                                    pyr_scale=0.5, levels=3, winsize=15,
                                                    iterations=3, poly_n=5, poly_sigma=1.2, flags=0)
                prev_gray = curr_gray
            # Compute median magnitude and angle of the optical flow.
            mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1], angleInDegrees=True)
            median_mag = np.median(mag)
            median_ang = np.median(ang)
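            # Caveat: np.median treats angles as plain numbers and ignores the
            # wrap-around at 0/360 degrees, so the median direction is a rough
            # estimate that is most reliable when inter-frame motion is small
            # and tightly clustered.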
# Compute a "zoom factor": fraction of pixels moving away from the center. | |
h, w = flow.shape[:2] | |
center_x, center_y = w / 2, h / 2 | |
x_coords, y_coords = np.meshgrid(np.arange(w), np.arange(h)) | |
x_offset = x_coords - center_x | |
y_offset = y_coords - center_y | |
dot = flow[..., 0] * x_offset + flow[..., 1] * y_offset | |
zoom_factor = np.count_nonzero(dot > 0) / (w * h) | |
            writer.writerow({
                'frame': frame_idx,
                'mag': median_mag,
                'ang': median_ang,
                'zoom': zoom_factor
            })
            if frame_idx % 10 == 0 or frame_idx == total_frames:
                print(f"[INFO] Processed frame {frame_idx}/{total_frames}")
            frame_idx += 1
    cap.release()
    elapsed = time.time() - start_time
    print(f"[INFO] Motion CSV generated: {output_csv} in {elapsed:.2f} seconds")
    return output_csv
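# Minimal usage sketch (hypothetical path; the Gradio handler below calls this
# with an uploaded temp file instead):
#   csv_path = generate_motion_csv("input.mp4")
#   # csv_path now holds one row per processed frame: frame, mag, ang, zoom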
def read_motion_csv(csv_filename):
    """
    Reads a motion CSV file (with columns: frame, mag, ang, zoom) and computes a cumulative
    offset per frame for stabilization.
    Returns:
        A dictionary mapping frame numbers to (dx, dy) offsets.
    """
    print(f"[INFO] Reading motion CSV: {csv_filename}")
    motion_data = {}
    cumulative_dx = 0.0
    cumulative_dy = 0.0
    with open(csv_filename, 'r') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            frame_num = int(row['frame'])
            mag = float(row['mag'])
            ang = float(row['ang'])
            rad = math.radians(ang)
            dx = mag * math.cos(rad)
            dy = mag * math.sin(rad)
            cumulative_dx += dx
            cumulative_dy += dy
            motion_data[frame_num] = (-cumulative_dx, -cumulative_dy)
    print("[INFO] Completed reading motion CSV.")
    return motion_data
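# Worked example of the offset math above: a single row with mag=2.0, ang=0
# gives dx=2.0, dy=0.0, so the stored offset is (-2.0, 0.0) -- the frame is
# shifted left by 2 px to cancel the detected rightward motion. Offsets are
# cumulative, so a second identical row would map to (-4.0, 0.0).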
def stabilize_video_using_csv(video_file, csv_file, zoom=1.0, output_file=None):
    """
    Stabilizes the input video using motion data from the CSV file.
    """
    start_time = time.time()
    print(f"[INFO] Starting stabilization using CSV: {csv_file}")
    motion_data = read_motion_csv(csv_file)
    cap = cv2.VideoCapture(video_file)
    if not cap.isOpened():
        raise ValueError("[ERROR] Could not open video file for stabilization.")
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    print(f"[INFO] Video properties - FPS: {fps}, Width: {width}, Height: {height}")
    if output_file is None:
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
        output_file = temp_file.name
        temp_file.close()
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_file, fourcc, fps, (width, height))
    frame_idx = 1
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    print(f"[INFO] Total frames to stabilize: {total_frames}")
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # Optionally apply zoom (resize and center-crop).
        if zoom != 1.0:
            zoomed_frame = cv2.resize(frame, None, fx=zoom, fy=zoom, interpolation=cv2.INTER_LINEAR)
            zoomed_h, zoomed_w = zoomed_frame.shape[:2]
            start_x = max((zoomed_w - width) // 2, 0)
            start_y = max((zoomed_h - height) // 2, 0)
            frame = zoomed_frame[start_y:start_y + height, start_x:start_x + width]
        dx, dy = motion_data.get(frame_idx, (0, 0))
        transform = np.array([[1, 0, dx],
                              [0, 1, dy]], dtype=np.float32)
        stabilized_frame = cv2.warpAffine(frame, transform, (width, height))
        out.write(stabilized_frame)
        if frame_idx % 10 == 0 or frame_idx == total_frames:
            print(f"[INFO] Stabilized frame {frame_idx}/{total_frames}")
        frame_idx += 1
    cap.release()
    out.release()
    elapsed = time.time() - start_time
    print(f"[INFO] Stabilized video saved to: {output_file} in {elapsed:.2f} seconds")
    return output_file
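# Minimal usage sketch (hypothetical paths, chaining the two stages above):
#   csv_path = generate_motion_csv("shaky.mp4")
#   stable_path = stabilize_video_using_csv("shaky.mp4", csv_path, zoom=1.1)
# A zoom slightly above 1.0 crops the frame so the per-frame translation does
# not expose black borders at the edges.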
def process_video_ai(video_file, zoom):
    """
    Gradio interface function:
      - Generates motion data (CSV) from the input video using an AI model (RAFT if available, else Farneback).
      - Stabilizes the video based on the generated motion data.
    Returns:
        Tuple containing the original video file path, the stabilized video file path, and log output.
    """
    log_buffer = io.StringIO()
    with redirect_stdout(log_buffer):
        if isinstance(video_file, dict):
            video_file = video_file.get("name", None)
        if video_file is None:
            raise ValueError("[ERROR] Please upload a video file.")
        print("[INFO] Starting AI-powered video processing...")
        csv_file = generate_motion_csv(video_file)
        stabilized_path = stabilize_video_using_csv(video_file, csv_file, zoom=zoom)
        print("[INFO] Video processing complete.")
    logs = log_buffer.getvalue()
    return video_file, stabilized_path, logs
# Build the Gradio UI.
with gr.Blocks() as demo:
    gr.Markdown("# AI-Powered Video Stabilization")
    gr.Markdown("Upload a video and select a zoom factor. The system will generate motion data using an AI model (RAFT if available) and then stabilize the video.")
    with gr.Row():
        with gr.Column():
            video_input = gr.Video(label="Input Video")
            zoom_slider = gr.Slider(minimum=1.0, maximum=2.0, step=0.1, value=1.0, label="Zoom Factor")
            process_button = gr.Button("Process Video")
        with gr.Column():
            original_video = gr.Video(label="Original Video")
            stabilized_video = gr.Video(label="Stabilized Video")
            logs_output = gr.Textbox(label="Logs", lines=10)
    process_button.click(
        fn=process_video_ai,
        inputs=[video_input, zoom_slider],
        outputs=[original_video, stabilized_video, logs_output]
    )
demo.launch()