from xai import generate_grad_cam_overlay
import torch
import gradio as gr
import spaces
import shutil
import gc
from pathlib import Path
import cv2
import numpy as np
from PIL import Image
from transformers import AutoModelForImageClassification
from timm.data.transforms_factory import create_transform

device = "cuda" if torch.cuda.is_available() else "cpu"

MODEL_ID = "alessiopittiglio/MambaVision-FFPP"
# Standard ImageNet normalization statistics.
NORM_MEAN = (0.485, 0.456, 0.406)
NORM_STD = (0.229, 0.224, 0.225)
NUM_FRAMES_TO_PROCESS = 15
IMG_HEIGHT = 720
IMG_WIDTH = 540
IMG_WEIGHT = IMG_WIDTH  # Backward-compatible alias for the original (misspelled) name.

# Parent directory of the per-session scratch directories.
SESSION_ROOT = Path("/tmp")

# trust_remote_code=True executes the model code shipped in the Hub repository.
model = AutoModelForImageClassification.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
)
model.eval()

# Evaluation-mode (non-augmenting) preprocessing producing normalized
# 3 x IMG_HEIGHT x IMG_WIDTH tensors; crop_pct=1.0 keeps the whole resized image.
transform = create_transform(
    input_size=(3, IMG_HEIGHT, IMG_WIDTH),
    is_training=False,
    crop_mode="center",
    crop_pct=1.0,
    mean=NORM_MEAN,
    std=NORM_STD,
)


def _session_dir(session_hash) -> Path:
    """Return the scratch directory for ``session_hash`` directly under SESSION_ROOT.

    The hash comes from the client request and the directory is later removed
    with ``shutil.rmtree``, so anything that would not be a plain child of
    SESSION_ROOT (path separators, "..", absolute paths, empty) is rejected.

    Raises:
        ValueError: If the hash does not map to a direct child of SESSION_ROOT.
    """
    session_dir = SESSION_ROOT / f"{session_hash}"
    if session_dir.parent != SESSION_ROOT or session_dir.name in ("", ".", ".."):
        raise ValueError(f"Unsafe session hash: {session_hash!r}")
    return session_dir


def start_session(request: gr.Request):
    """Create the per-session scratch directory and return its POSIX path."""
    session_hash = request.session_hash
    session_dir = _session_dir(session_hash)
    session_dir.mkdir(parents=True, exist_ok=True)
    print(f"Session with hash {session_hash} started.")
    return session_dir.as_posix()


def end_session(request: gr.Request):
    """Delete the per-session scratch directory when the session ends."""
    session_hash = request.session_hash
    try:
        session_dir = _session_dir(session_hash)
    except ValueError as e:
        print(f"Skipping session cleanup: {e}")
        return
    if session_dir.exists():
        shutil.rmtree(session_dir)
    print(f"Session with hash {session_hash} ended.")


def extract_frames(video_filepath_temp: str, num_frames: int = 10) -> list:
    """Sample up to ``num_frames`` evenly spaced RGB frames from a video.

    Args:
        video_filepath_temp: Path of the video file to read.
        num_frames: Maximum number of frames to sample; capped at the frame
            count reported by OpenCV.

    Returns:
        A list of RGB ``PIL.Image`` frames. It is empty if the file cannot be
        opened, reports no frames, or ``num_frames`` < 1. Frames that fail to
        decode are skipped, so the list may be shorter than requested.
    """
    frames = []
    if num_frames < 1:
        return frames

    cap = cv2.VideoCapture(video_filepath_temp)
    try:
        if not cap.isOpened():
            print(f"Error: Could not open video file {video_filepath_temp}")
            return frames

        length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if length < 1:
            return frames

        actual_num_frames = min(num_frames, length)
        indices = np.linspace(0, length - 1, actual_num_frames).astype(int)
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
            ret, frame = cap.read()
            # A seek/read can fail (e.g. the reported frame count is not exact);
            # skip that frame rather than aborting the whole extraction.
            if not ret or frame is None:
                continue
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes to BGR.
            frames.append(Image.fromarray(frame_rgb))
    finally:
        cap.release()
    return frames


def _explain_most_fake_frame(frames, probs_fake, predicted_label, cam_method):
    """Build a CAM overlay for the frame with the highest FAKE probability.

    Explanation is best-effort: on any failure the user gets a warning and
    ``None`` is returned, so the prediction itself is still shown.
    """
    try:
        best_frame_idx = torch.argmax(probs_fake).item()
        selected_frame = frames[best_frame_idx]
        input_tensor = transform(selected_frame).unsqueeze(0).to(device)

        # The CAM is computed on the wrapped backbone, hooking its final `norm` layer.
        model_for_xai = model.model
        target_layer = model_for_xai.norm
        # Explain the class that was actually predicted (0 = FAKE, 1 = REAL, as above).
        class_to_explain = 0 if predicted_label == "FAKE" else 1

        return generate_grad_cam_overlay(
            model=model_for_xai,
            target_layer=target_layer,
            input_tensor=input_tensor,
            original_frames=[selected_frame],
            target_class_idx=class_to_explain,
            cam_method=cam_method,
        )
    except Exception as xai_e:
        print(f"Error during XAI generation: {xai_e}")
        gr.Warning(f"Failed to generate XAI output: {xai_e}", duration=None)
        return None


def _release_model():
    """Move the model back to CPU and free cached GPU memory (best effort)."""
    try:
        if device == "cuda":
            model.cpu()
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()
    except Exception as cleanup_e:
        print(f"Error during model cleanup: {cleanup_e}")
        gr.Warning(f"Issue during model cleanup: {cleanup_e}", duration=5)


@spaces.GPU
def predict_deepfake(video_path, session_dir, cam_method):
    """Classify an uploaded video as FAKE or REAL and explain the decision.

    Samples up to NUM_FRAMES_TO_PROCESS frames, classifies each one, averages
    the per-frame probabilities, and renders a heatmap for the frame with the
    highest FAKE probability.

    Args:
        video_path: Path of the uploaded video (from ``gr.Video``); may be empty.
        session_dir: Per-session scratch directory from ``start_session``
            (currently unused).
        cam_method: CAM variant name ("GradCAM" or "GradCAM++"), passed
            through to ``generate_grad_cam_overlay``.

    Returns:
        ``(result_markdown, heatmap)``. ``heatmap`` is whatever
        ``generate_grad_cam_overlay`` returns, or ``None`` if the explanation
        failed.

    Raises:
        gr.Error: If no video was uploaded, the video could not be read, or
            the GPU ran out of memory.
    """
    if not video_path:
        raise gr.Error("No video uploaded.", duration=None)

    original_path_name = Path(video_path).name
    video_name = Path(video_path).stem

    try:
        gr.Info(f"Loading video file: {video_name}", duration=2)
        frames = extract_frames(video_path, num_frames=NUM_FRAMES_TO_PROCESS)
    except Exception as load_e:
        raise gr.Error(
            f"Failed to load video file {original_path_name}: {load_e}", duration=None
        ) from load_e
    if not frames:
        # extract_frames reports unreadable videos with an empty list, and
        # torch.stack cannot handle an empty list.
        raise gr.Error(
            f"Failed to load video file {original_path_name}: no frames could be read.",
            duration=None,
        )

    try:
        model.to(device)
        gr.Info(f"Processing {original_path_name} on {device}...", duration=2)

        batch = torch.stack([transform(frame) for frame in frames])  # (N, C, H, W)
        batch = batch.to(device, dtype=torch.float32)
        with torch.no_grad():
            output = model(batch)  # logits: (N, num_classes)
        logits = output["logits"]
        probs = torch.softmax(logits, dim=1)  # (N, num_classes)

        # NOTE(review): assumes class index 0 = FAKE and 1 = REAL; confirm
        # against model.config.id2label.
        probs_fake = probs[:, 0]  # (N,)
        probs_real = probs[:, 1]
        avg_prob_fake = probs_fake.mean().item()
        avg_prob_real = probs_real.mean().item()
        predicted_label = "FAKE" if avg_prob_fake >= 0.5 else "REAL"

        result = (
            f"**Prediction:** {predicted_label}\n"
            f"**Confidence FAKE:** {avg_prob_fake:.2%}\n"
            f"**Confidence REAL:** {avg_prob_real:.2%}\n"
        )

        grad_cam_output = _explain_most_fake_frame(
            frames, probs_fake, predicted_label, cam_method
        )

        gr.Info("Prediction complete.", duration=2)
        return result, grad_cam_output
    except torch.cuda.OutOfMemoryError as e:
        error_msg = "CUDA out of memory. Please try a shorter video or reduce GPU load."
        print(f"CUDA OutOfMemoryError: {e}")
        raise gr.Error(error_msg, duration=None) from e
    finally:
        _release_model()


# NOTE(review): the HTML tags of this text were lost in this copy of the file;
# the markup below is a minimal reconstruction around the surviving text.
article = (
    "<div>"
    "<p>This demo showcases a system for analyzing video authenticity, developed as part of a project for the "
    "Ethics in Artificial Intelligence course (Artificial Intelligence, University of Bologna).</p>"
    "<p><strong>Key Features:</strong></p>"
    # TODO(review): the feature list items were lost from this copy of the file; restore them.
    "<ul></ul>"
    "<p><em>Disclaimer: This is a research and educational prototype. "
    "Results may vary and should not be considered definitive evidence of manipulation.</em></p>"
    "</div>"
)

with gr.Blocks() as demo:
    gr.Markdown("<h1>👁️ Deepfake Video Detector</h1>")
    gr.HTML(article)

    # Holds the path returned by start_session for the current browser session.
    session_dir_state = gr.State()
    demo.load(start_session, outputs=[session_dir_state])

    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.Video(sources=["upload"], label="Upload Video File")
            cam_dropdown = gr.Dropdown(
                choices=["GradCAM", "GradCAM++"],
                value="GradCAM",
                label="XAI Method",
            )
            predict_btn = gr.Button("Analyze Video", variant="primary")
        with gr.Column(scale=2):
            xai_output = gr.Image(label="XAI Heatmap (e.g., Grad-CAM)", type="pil")
            gr.Markdown("<h3>Prediction Results</h3>")
            results_output = gr.Markdown(label="Prediction Results", line_breaks=True)

    predict_btn.click(
        fn=predict_deepfake,
        inputs=[file_input, session_dir_state, cam_dropdown],
        outputs=[results_output, xai_output],
        api_name="predict_deepfake",
    )
    demo.unload(end_session)

if __name__ == "__main__":
    print("Launching Gradio Demo...")
    demo.queue()
    demo.launch()