import csv
import math
import os
import tempfile

import cv2
import gradio as gr
import numpy as np
import torch

# Frame rate used for the output video when the container reports none.
_FALLBACK_FPS = 30.0

# Set up device for torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Try to load the RAFT model from torch.hub.
# If it fails (e.g. due to repository structure changes), we will fall back to
# OpenCV optical flow. The broad `except` is deliberate: any failure here
# (network, missing hubconf, incompatible weights) must not stop the app.
try:
    # The trust_repo parameter might prompt for confirmation; set it to True.
    raft_model = torch.hub.load("princeton-vl/RAFT", "raft_small", pretrained=True, trust_repo=True)
    raft_model = raft_model.to(device)
    raft_model.eval()
    print("RAFT model loaded successfully.")
except Exception as e:
    print("Error loading RAFT model:", e)
    print("Falling back to OpenCV optical flow for motion CSV generation.")
    raft_model = None


def _frame_to_tensor(frame_bgr):
    """Convert a BGR frame into a 1x3xHxW float tensor (scaled by 1/255) on `device`."""
    # NOTE(review): some RAFT implementations expect 0-255 inputs and H/W
    # divisible by 8 -- confirm against whatever model torch.hub returns.
    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    tensor = torch.from_numpy(frame_rgb).permute(2, 0, 1).float().unsqueeze(0) / 255.0
    return tensor.to(device)


def _summarize_flow(flow, x_offset, y_offset):
    """Reduce a dense HxWx2 flow field to (magnitude, angle in degrees, zoom fraction)."""
    # Take the median of the x and y components and convert that single vector
    # to polar form. A median over per-pixel angles is meaningless near the
    # 0/360 degree wrap (e.g. half the pixels at 359 and half at 1 would yield
    # roughly 180).
    median_dx = float(np.median(flow[..., 0]))
    median_dy = float(np.median(flow[..., 1]))
    magnitude = math.hypot(median_dx, median_dy)
    angle = math.degrees(math.atan2(median_dy, median_dx)) % 360.0

    # "Zoom factor": fraction of pixels whose flow points away from the centre
    # (positive dot product with the pixel's offset from the centre).
    radial = flow[..., 0] * x_offset + flow[..., 1] * y_offset
    zoom_fraction = np.count_nonzero(radial > 0) / radial.size
    return magnitude, angle, zoom_fraction


def generate_motion_csv(video_file, output_csv=None):
    """
    Generates a CSV file with motion data (columns: frame, mag, ang, zoom) from an input video.
    If the RAFT model is available, it uses it to compute optical flow; otherwise, it falls back
    to OpenCV's Farneback optical flow.

    One row is written per frame after the first. Row ``frame=k`` describes the
    flow from frame ``k-1`` to frame ``k`` (zero-based), so the first frame has
    no row. ``mag``/``ang`` are the polar form of the median flow vector and
    ``zoom`` is the fraction of pixels moving away from the image centre.

    Args:
        video_file (str): Path to the input video.
        output_csv (str): Optional output CSV file path. If None, a temporary file is created.

    Returns:
        output_csv (str): Path to the generated CSV file.

    Raises:
        ValueError: If the video cannot be opened or its first frame cannot be read.
    """
    cap = cv2.VideoCapture(video_file)
    try:
        if not cap.isOpened():
            raise ValueError("Could not open video file for CSV generation.")

        ret, first_frame = cap.read()
        if not ret:
            raise ValueError("Cannot read first frame from video.")

        # Create the temporary file only once the video is known to be
        # readable, so a failed call does not leave a stray file behind.
        if output_csv is None:
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.csv')
            output_csv = temp_file.name
            temp_file.close()

        use_raft = raft_model is not None
        if use_raft:
            previous = _frame_to_tensor(first_frame)
        else:
            previous = cv2.cvtColor(first_frame, cv2.COLOR_BGR2GRAY)

        # Per-pixel offsets from the image centre; built on first use and
        # reused for every frame with the same flow shape.
        offsets_shape = None
        x_offset = y_offset = None

        with open(output_csv, 'w', newline='') as csvfile:
            fieldnames = ['frame', 'mag', 'ang', 'zoom']
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()

            frame_idx = 1
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                if use_raft:
                    current = _frame_to_tensor(frame)
                    with torch.no_grad():
                        _, flow_up = raft_model(previous, current, iters=20, test_mode=True)
                    flow = flow_up[0].permute(1, 2, 0).cpu().numpy()
                else:
                    current = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                    flow = cv2.calcOpticalFlowFarneback(
                        previous, current, None,
                        pyr_scale=0.5, levels=3, winsize=15,
                        iterations=3, poly_n=5, poly_sigma=1.2, flags=0,
                    )
                previous = current

                if flow.shape[:2] != offsets_shape:
                    h, w = flow.shape[:2]
                    x_offset, y_offset = np.meshgrid(
                        np.arange(w) - w / 2, np.arange(h) - h / 2
                    )
                    offsets_shape = (h, w)

                mag, ang, zoom_factor = _summarize_flow(flow, x_offset, y_offset)
                writer.writerow({
                    'frame': frame_idx,
                    'mag': mag,
                    'ang': ang,
                    'zoom': zoom_factor,
                })
                frame_idx += 1
    finally:
        # Release the capture on every path, including the error ones.
        cap.release()

    print(f"Motion CSV generated: {output_csv}")
    return output_csv


def read_motion_csv(csv_filename):
    """
    Reads a motion CSV file (with columns: frame, mag, ang, zoom) and computes a cumulative
    offset per frame (the negative cumulative displacement) for stabilization.

    Each row's polar (mag, ang-in-degrees) motion is converted to (dx, dy) and
    summed in file order; the ``zoom`` column is not used here.

    Returns:
        A dictionary mapping frame numbers to (dx, dy) offsets.
    """
    motion_data = {}
    cumulative_dx = 0.0
    cumulative_dy = 0.0
    # newline='' is what the csv module recommends for files it reads.
    with open(csv_filename, 'r', newline='') as csvfile:
        for row in csv.DictReader(csvfile):
            rad = math.radians(float(row['ang']))
            mag = float(row['mag'])
            cumulative_dx += mag * math.cos(rad)
            cumulative_dy += mag * math.sin(rad)
            motion_data[int(row['frame'])] = (-cumulative_dx, -cumulative_dy)
    return motion_data


def stabilize_video_using_csv(video_file, csv_file, zoom=1.0, output_file=None):
    """
    Stabilizes the input video using motion data from the CSV file.

    Frames are counted from 0, matching the numbering written by
    ``generate_motion_csv`` (row ``frame=k`` belongs to the k-th frame after
    the first). Frames without a CSV row, such as the first one, are written
    without any shift.

    Args:
        video_file (str): Path to the input video.
        csv_file (str): Path to the motion CSV file.
        zoom (float): Zoom factor to apply before stabilization (default: 1.0).
        output_file (str): Path for the output stabilized video. If None, a temporary file is created.

    Returns:
        output_file (str): Path to the stabilized video file.

    Raises:
        ValueError: If the input video cannot be opened or the output video
            cannot be opened for writing.
    """
    motion_data = read_motion_csv(csv_file)

    cap = cv2.VideoCapture(video_file)
    out = None
    try:
        if not cap.isOpened():
            raise ValueError("Could not open video file for stabilization.")

        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # Covers both 0 and NaN, which would otherwise give the writer an
            # unusable frame rate.
            fps = _FALLBACK_FPS
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        if output_file is None:
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
            output_file = temp_file.name
            temp_file.close()

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_file, fourcc, fps, (width, height))
        if not out.isOpened():
            raise ValueError("Could not open output video file for writing.")

        # Start at 0: the first frame is the reference and has no CSV row.
        frame_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            if zoom != 1.0:
                # Scale the frame, then take a centred crop of the original size.
                zoomed_frame = cv2.resize(frame, None, fx=zoom, fy=zoom, interpolation=cv2.INTER_LINEAR)
                zoomed_h, zoomed_w = zoomed_frame.shape[:2]
                start_x = max((zoomed_w - width) // 2, 0)
                start_y = max((zoomed_h - height) // 2, 0)
                frame = zoomed_frame[start_y:start_y + height, start_x:start_x + width]

            dx, dy = motion_data.get(frame_idx, (0, 0))
            # Pure translation by the negative cumulative displacement.
            transform = np.array([[1, 0, dx], [0, 1, dy]], dtype=np.float32)
            stabilized_frame = cv2.warpAffine(frame, transform, (width, height))

            out.write(stabilized_frame)
            frame_idx += 1
    finally:
        # Release the capture and the writer on every path.
        cap.release()
        if out is not None:
            out.release()

    print(f"Stabilized video saved to: {output_file}")
    return output_file


def process_video_ai(video_file, zoom):
    """
    Gradio interface function:
      - Generates motion data (CSV) from the input video using an AI model (RAFT, if available).
      - Stabilizes the video based on the generated motion data.

    The intermediate CSV is a temporary file and is removed once stabilization
    has finished or failed.

    Returns:
        Tuple containing the original video file path and the stabilized video file path.

    Raises:
        ValueError: If no video was supplied, or if processing the video fails.
    """
    if isinstance(video_file, dict):
        video_file = video_file.get("name", None)
    if video_file is None:
        raise ValueError("Please upload a video file.")

    # Generate motion CSV using the AI model (or fallback) for optical flow.
    csv_file = generate_motion_csv(video_file)
    try:
        # Stabilize the video using the generated CSV.
        stabilized_path = stabilize_video_using_csv(video_file, csv_file, zoom=zoom)
    finally:
        # Best-effort cleanup; a failed delete must not mask the real result.
        try:
            os.remove(csv_file)
        except OSError:
            pass
    return video_file, stabilized_path


# Build the Gradio UI.
with gr.Blocks() as demo:
    gr.Markdown("# AI-Powered Video Stabilization")
    gr.Markdown(
        "Upload a video and select a zoom factor. The system will automatically "
        "generate motion data (video.flow.csv) using an AI model (RAFT, if available) "
        "and then stabilize the video."
    )

    with gr.Row():
        with gr.Column():
            video_input = gr.Video(label="Input Video")
            zoom_slider = gr.Slider(minimum=1.0, maximum=2.0, step=0.1, value=1.0, label="Zoom Factor")
            process_button = gr.Button("Process Video")
        with gr.Column():
            original_video = gr.Video(label="Original Video")
            stabilized_video = gr.Video(label="Stabilized Video")

    process_button.click(
        fn=process_video_ai,
        inputs=[video_input, zoom_slider],
        outputs=[original_video, stabilized_video],
    )

demo.launch()