Spaces:
Runtime error
Runtime error
from xai import generate_grad_cam_overlay | |
import torch | |
import gradio as gr | |
import spaces | |
import shutil | |
import gc | |
from pathlib import Path | |
import cv2 | |
import numpy as np | |
from PIL import Image | |
from transformers import AutoModelForImageClassification | |
from timm.data.transforms_factory import create_transform | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
MODEL_ID = "alessiopittiglio/MambaVision-FFPP" | |
NORM_MEAN = (0.485, 0.456, 0.406) | |
NORM_STD = (0.229, 0.224, 0.225) | |
NUM_FRAMES_TO_PROCESS = 15 | |
IMG_HEIGHT = 720 | |
IMG_WEIGHT = 540 | |
model = AutoModelForImageClassification.from_pretrained( | |
MODEL_ID, | |
trust_remote_code=True, | |
) | |
model.eval() | |
transform = create_transform( | |
input_size=(3, IMG_HEIGHT, IMG_WEIGHT), | |
is_training=False, | |
crop_mode="center", | |
crop_pct=1.0, | |
mean=NORM_MEAN, | |
std=NORM_STD, | |
) | |
def start_session(request: gr.Request): | |
session_hash = request.session_hash | |
session_dir = Path(f"/tmp/{session_hash}") | |
session_dir.mkdir(parents=True, exist_ok=True) | |
print(f"Session with hash {session_hash} started.") | |
return session_dir.as_posix() | |
def end_session(request: gr.Request): | |
session_hash = request.session_hash | |
session_dir = Path(f"/tmp/{session_hash}") | |
if session_dir.exists(): | |
shutil.rmtree(session_dir) | |
print(f"Session with hash {session_hash} ended.") | |
def extract_frames(video_filepath_temp: str, num_frames: int = 10): | |
frames = [] | |
cap = cv2.VideoCapture(video_filepath_temp) | |
if not cap.isOpened(): | |
print(f"Error: Could not open video file {video_filepath_temp}") | |
return frames | |
length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) | |
if length < 1: | |
cap.release() | |
return frames | |
actual_num_frames = min(num_frames, length) | |
indices = np.linspace(0, length - 1, actual_num_frames).astype(int) | |
for idx in indices: | |
cap.set(cv2.CAP_PROP_POS_FRAMES, idx) | |
ret, frame = cap.read() | |
if not ret or frame is None: | |
continue | |
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | |
frames.append(Image.fromarray(frame_rgb)) | |
cap.release() | |
return frames | |
def predict_deepfake(video_path, session_dir, cam_method): | |
if not video_path: | |
gr.Error("No video uploaded.", duration=None) | |
return [], [] | |
grad_cam_output = None | |
target_layer = None | |
original_path_name = Path(video_path).name | |
video_name = Path(video_path).stem | |
try: | |
gr.Info(f"Loading video file: {video_name}", duration=2) | |
frames = extract_frames(video_path, num_frames=NUM_FRAMES_TO_PROCESS) | |
except Exception as load_e: | |
gr.Error( | |
f"Failed to load video file {original_path_name}: {load_e}", duration=None | |
) | |
try: | |
model.to(device) | |
gr.Info(f"Processing {original_path_name} on {device}...", duration=2) | |
batch = torch.stack([transform(frame) for frame in frames]) # (N, C, H, W) | |
batch = batch.to(device, dtype=torch.float32) | |
with torch.no_grad(): | |
output = model(batch) # (N, num_classes) | |
logits = output["logits"] | |
probs = torch.softmax(logits, dim=1) # (N, num_classes) | |
probs_fake = probs[:, 0] # (N,) | |
probs_real = probs[:, 1] | |
avg_prob_fake = probs_fake.mean().item() | |
avg_prob_real = probs_real.mean().item() | |
predicted_label = "FAKE" if avg_prob_fake >= 0.5 else "REAL" | |
result = ( | |
f"**Prediction:** {predicted_label}\n" | |
f"**Confidence FAKE:** {avg_prob_fake:.2%}\n" | |
f"**Confidence REAL:** {avg_prob_real:.2%}\n" | |
) | |
try: | |
best_frame_idx = torch.argmax(probs_fake).item() | |
selected_frame = frames[best_frame_idx] | |
input_tensor = transform(selected_frame).unsqueeze(0).to(device) | |
model_for_xai = model.model | |
target_layer = model_for_xai.norm | |
class_to_explain = 0 if predicted_label == "FAKE" else 1 | |
grad_cam_output = generate_grad_cam_overlay( | |
model=model_for_xai, | |
target_layer=target_layer, | |
input_tensor=input_tensor, | |
original_frames=[selected_frame], | |
target_class_idx=class_to_explain, | |
cam_method=cam_method, | |
) | |
except Exception as xai_e: | |
gr.Error(f"Failed to generate XAI output: {xai_e}", duration=None) | |
print(f"Error during XAI generation: {xai_e}") | |
gr.Info("Prediction complete.", duration=2) | |
return result, grad_cam_output | |
except torch.cuda.OutOfMemoryError as e: | |
error_msg = "CUDA out of memory. Please try a shorter audio or reduce GPU load." | |
print(f"CUDA OutOfMemoryError: {e}") | |
gr.Error(error_msg, duration=None) | |
finally: | |
try: | |
if "model" in locals() and hasattr(model, "cpu"): | |
if device == "cuda": | |
model.cpu() | |
gc.collect() | |
if device == "cuda": | |
torch.cuda.empty_cache() | |
except Exception as cleanup_e: | |
print(f"Error during model cleanup: {cleanup_e}") | |
gr.Warning(f"Issue during model cleanup: {cleanup_e}", duration=5) | |
article = ( | |
"<p style='font-size: 1.1em;'>" | |
"This demo showcases a system for analyzing video authenticity, developed as part of a project for the " | |
"Ethics in Artificial Intelligence course (Artificial Intelligence, University of Bologna)." | |
"</p>" | |
"<p><strong style='color: red; font-size: 1.2em;'>Key Features:</strong></p>" | |
"<ul style='font-size: 1.1em;'>" | |
" <li>Classifies videos as REAL or FAKE</li>" | |
" <li>Provides an estimated probability for the prediction</li>" | |
" <li>Grad-CAM visualization to highlight image regions influencing the model's decision on a sample frame</li>" | |
"</ul>" | |
"<p style='text-align: center; margin-top: 1em;'>" | |
"<p><strong>Disclaimer:</strong> This is a research and educational prototype. " | |
"Results may vary and should not be considered definitive evidence of manipulation." | |
"</p>" | |
) | |
with gr.Blocks() as demo: | |
gr.Markdown( | |
f"<h1 style='text-align: center; margin: 0 auto;'>👁️ Deepfake Video Detector</h1>" | |
) | |
gr.HTML(article) | |
current_video_path_state = gr.State(None) | |
session_dir = gr.State() | |
demo.load(start_session, outputs=[session_dir]) | |
with gr.Row(): | |
with gr.Column(scale=1): | |
file_input = gr.Video(sources=["upload"], label="Upload Video File") | |
cam_dropdown = gr.Dropdown( | |
choices=["GradCAM", "GradCAM++"], | |
value="GradCAM", | |
label="XAI Method", | |
) | |
predict_btn = gr.Button("Analyze Video", variant="primary") | |
with gr.Column(scale=2): | |
xai_output = gr.Image(label="XAI Heatmap (e.g., Grad-CAM)", type="pil") | |
gr.Markdown( | |
"<p><strong style='color: #FF0000; font-size: 1.2em;'>Prediction Results</strong></p>" | |
) | |
results_output = gr.Markdown(label="Prediction Results", line_breaks=True) | |
predict_btn.click( | |
fn=predict_deepfake, | |
inputs=[file_input, session_dir, cam_dropdown], | |
outputs=[results_output, xai_output], | |
api_name="predict_deepfake", | |
) | |
demo.unload(end_session) | |
if __name__ == "__main__": | |
print("Launching Gradio Demo...") | |
demo.queue() | |
demo.launch() | |