# viralplay / app.py
# Commit f539c08 ("Update app.py", verified) by phitran.
import os
import shutil
import cv2
import spaces
import gradio as gr
from handlers import frame_handler_yolo as fh
from handlers import video_handler as vh
# YOLOv8 weights file handed to both YOLO-based frame handlers (key-frame
# extraction and object-preserving cropping). Relative path, so it is resolved
# against the process working directory.
# NOTE(review): confirm the weights ship with the Space or are fetched on first use.
model_path: str = "yolov8n.pt"
@spaces.GPU(duration=150)
def process_video(video_file):
    """
    Turn an uploaded match video into a short vertical clip.

    This is a generator used as a Gradio event handler. Every ``yield`` is a
    ``(status_message, video_path_or_None)`` pair that updates the status
    textbox and the output video component. Pipeline:

    1. copy the upload into ``output_data/`` (wiping any previous run),
    2. extract all frames at the source frame rate (``vh.extract_frames_by_rate``),
    3. select key frames with the YOLO model at ``model_path``
       (``fh.extract_key_frames``),
    4. crop key frames to a (360, 640) target (``fh.crop_preserve_key_objects``),
    5. assemble the cropped frames into a 24 fps video
       (``vh.create_video_from_frames``).

    Args:
        video_file: The uploaded video. ``gr.File(type="filepath")`` and the
            default set by ``demo.load`` supply a plain path string. Path-like
            objects and objects exposing the path via ``.name`` are accepted
            too. ``None`` (no file selected) is reported as an error status.

    Yields:
        ``(status_message, None)`` while working and on errors, then
        ``(status_message, processed_video_path)`` once the video is written.
    """
    # Shared working directory, relative to the process CWD. It is wiped at
    # the start of every run, so two runs executing simultaneously would
    # overwrite each other's files.
    # NOTE(review): switch to a per-request temp dir if concurrency is raised.
    output_folder = "output_data"
    all_frames_folder = os.path.join(output_folder, "all_frames")
    key_frames_folder = os.path.join(output_folder, "key_frames")
    cropped_frames_folder = os.path.join(output_folder, "cropped_frames")
    processed_video_path = os.path.join(output_folder, "processed_video.mp4")
    video_path = os.path.join(output_folder, "input_video.mp4")

    # The user cleared the file input (or never chose one): report it in the
    # status box instead of crashing the handler.
    if video_file is None:
        yield "Error: no video file provided. Please upload a video first.", None
        return

    # gr.File(type="filepath") delivers a str, not a file object, so
    # `video_file.name` alone would raise AttributeError. Keep `.name` as a
    # fallback for file-wrapper objects.
    if isinstance(video_file, (str, os.PathLike)):
        source_path = os.fspath(video_file)
    else:
        source_path = video_file.name

    print("Calling process_video function: Output folder:", output_folder)
    yield "Processing started...", None

    # Start from a clean slate so frames left over from a previous run are
    # not mixed into this one.
    if os.path.exists(output_folder):
        shutil.rmtree(output_folder)
    os.makedirs(output_folder, exist_ok=True)

    # copyfile streams the data instead of reading the whole video into memory.
    try:
        shutil.copyfile(source_path, video_path)
    except OSError as err:
        print(f"Error: Cannot copy video file {source_path}: {err}")
        yield f"Error: cannot read the uploaded video ({err}).", None
        return

    # Probe the frame rate, releasing the capture handle on every path.
    video = cv2.VideoCapture(video_path)
    try:
        opened = video.isOpened()
        # int() truncates fractional rates (e.g. 29.97 -> 29). The truncated
        # value is what the frame handlers below receive.
        original_fps = int(video.get(cv2.CAP_PROP_FPS)) if opened else 0
    finally:
        video.release()

    if not opened:
        print(f"Error: Cannot open video file {video_path}")
        yield "Error: cannot open the video file. Please upload a valid video.", None
        return
    # VideoCapture.get returns 0 for a property the backend cannot report. A
    # zero frame rate is not a usable sampling rate for the handlers.
    if original_fps <= 0:
        print(f"Error: Invalid frame rate ({original_fps}) for {video_path}")
        yield "Error: could not determine the video's frame rate.", None
        return

    # Step 1: Extract all frames at the source frame rate.
    yield "Extracting frames. Please wait...!", None
    vh.extract_frames_by_rate(video_path, all_frames_folder, original_fps)

    # Step 2: Extract key frames using the YOLO model.
    yield "Extracting key frames. Please wait...!", None
    fh.extract_key_frames(all_frames_folder, key_frames_folder, original_fps, model_path)

    # Step 3: Crop key frames to the target resolution while preserving key
    # objects (per the handler's contract).
    target_resolution = (360, 640)  # Output resolution (9:16)
    yield "Cropping key frames. Please wait...!", None
    fh.crop_preserve_key_objects(key_frames_folder, cropped_frames_folder, model_path, target_resolution)

    # Step 4: Assemble the short video at a fixed 24 fps.
    target_frame_rate = 24
    yield "Generating final video. Please wait...!", None
    vh.create_video_from_frames(cropped_frames_folder, processed_video_path, target_frame_rate, target_resolution)

    yield "Processing complete!", processed_video_path
# Gradio Blocks UI
with gr.Blocks() as demo:
    gr.Markdown("## Generate short video for your football match")
    gr.Markdown(
        "Upload a video file. The app will extract key frames, crop them to fit a "
        "9:16 aspect ratio, and generate a short video."
    )
    gr.Markdown(
        "Test data is auto loaded. Just click the Process Video button to see the "
        "result. Click x to upload your own video file."
    )

    with gr.Row():
        with gr.Column():
            # type="filepath": the component's value is a plain path string.
            video_input = gr.File(
                label="Upload Video",
                type="filepath",
                file_types=["video"],
                file_count="single",
                height=145,
            )
        with gr.Column():
            process_button = gr.Button("Process Video", variant="primary")
            status_output = gr.Textbox(label="Status", interactive=False)

    with gr.Row():
        with gr.Column():
            input_video_preview = gr.Video(label="Input Video (16:9) Preview", width=640, height=360)
        with gr.Column():
            video_output = gr.Video(label="Short Video (9:16) Generated", width=360, height=640)

    def update_preview(video_path):
        """Mirror the selected file into the preview player (None clears it)."""
        return video_path

    # Registered exactly once; the previous duplicate registration made the
    # preview update fire twice for every change.
    video_input.change(
        fn=update_preview,
        inputs=video_input,
        outputs=input_video_preview,
    )

    def set_default_video():
        """Return the bundled sample clip for both the file input and preview.

        Falls back to empty components when the sample file is absent, so the
        initial page load does not fail.
        """
        default_video_path = "./input_data/football.mp4"
        if not os.path.exists(default_video_path):
            return None, None
        return default_video_path, default_video_path

    # Pre-populate the input and preview with the sample video on page load.
    demo.load(
        fn=set_default_video,
        inputs=[],
        outputs=[video_input, input_video_preview],
    )

    # process_video is a generator: each yield streams (status, video) updates.
    process_button.click(process_video, inputs=video_input, outputs=[status_output, video_output])

if __name__ == "__main__":
    demo.launch()