import gradio as gr
import cv2
import numpy as np
import tensorflow as tf
import os
import pennylane as qml

# Load the trained Keras models (relative paths resolve against the current
# working directory). `cnn_model` is applied to 28x28 grayscale frames in
# `process_video`; `qcnn_model` is applied to the quanvolution features
# produced by `process_frame` / `quanv`.
cnn_model = tf.keras.models.load_model('noq_c_model.h5')
qcnn_model = tf.keras.models.load_model('q_model.h5')

# Quanvolutional layer configuration.
n_layers = 3  # Number of RandomLayers templates applied in `circuit`, each with a (2, 2) weight array
# NOTE(review): per PennyLane's RandomLayers docs a (2, 2) weight array means 2 layers of
# 2 rotations each, so the circuit likely applies 2 * n_layers random layers in total — confirm.
# Two-qubit simulator: each circuit evaluation encodes one pair of pixels.
dev = qml.device("default.qubit", wires=2)

# Seed NumPy's global RNG so the random circuit weights below are reproducible.
# NOTE(review): presumably these weights must match the ones used when qcnn_model was
# trained; changing the seed, the draw, or its shape would change the features.
np.random.seed(0)
# One (2, 2) weight array per template, drawn uniformly from [0, 2*pi).
rand_params = np.random.uniform(high=2 * np.pi, size=(n_layers, 2, 2))

@qml.qnode(dev)
def circuit(phi):
    """Evaluate the fixed, seeded random 2-qubit circuit on a pair of inputs.

    Args:
        phi: Sequence with at least two values; value ``i`` is scaled by pi
            and used as an RY rotation angle on wire ``i``.
            NOTE(review): `quanv` passes pixel intensities divided by 255 —
            confirm the expected range for any other caller.

    Returns:
        The Pauli-Z expectation values of wire 0 and wire 1, in that order,
        as returned by the QNode.
    """
    # Angle-encode each input on its own wire.
    for wire in range(2):
        qml.RY(np.pi * phi[wire], wires=wire)
    # Apply the module-level random layers; their weights are fixed, so every
    # evaluation runs the same circuit.
    for layer_idx in range(n_layers):
        layer_weights = rand_params[layer_idx]
        qml.templates.layers.RandomLayers(weights=layer_weights, wires=[0, 1])
    return [qml.expval(qml.PauliZ(wire)) for wire in range(2)]

def quanv(image):
    """Quantum-convolve a 28x28 single-channel image into a 14x14x2 feature map.

    The image is walked in non-overlapping steps of 2 along both axes. For
    output cell ``(r, c)``, the pixel at ``(2*r, 2*c)`` and its right-hand
    neighbour ``(2*r, 2*c + 1)`` are passed to ``circuit``. The two
    expectation values it returns become channels 0 and 1 of that cell.

    NOTE(review): odd-indexed rows and input channels other than 0 are never
    read — presumably intentional to match the trained model; confirm.

    Args:
        image: Array indexable as ``image[row, col, 0]`` for rows/cols 0..27.

    Returns:
        ``np.ndarray`` of shape (14, 14, 2).
    """
    features = np.zeros((14, 14, 2))
    for row in range(14):
        top = 2 * row
        for col in range(14):
            left = 2 * col
            expvals = circuit([image[top, left, 0], image[top, left + 1, 0]])
            for channel in range(2):
                features[row, col, channel] = expvals[channel]
    return features

# Root folder of the bundled example videos, with one sub-folder per category.
examples_dir = 'examples'
original_dir, deepfake_roop_dir, deepfake_web_dir = (
    os.path.join(examples_dir, category)
    for category in ('Original', 'DeepfakeRoop', 'DeepfakeWeb')
)

# Function to get video paths from a directory
def get_video_paths(directory, extension='.mp4'):
    """Return the paths of the video files directly inside *directory*.

    Args:
        directory: Folder to scan (not recursive).
        extension: File-name suffix to keep, matched case-insensitively
            (so ``clip.MP4`` is included as well as ``clip.mp4``).

    Returns:
        Sorted list of ``os.path.join(directory, name)`` paths. Sorting makes
        the example order stable, because ``os.listdir`` order is arbitrary.

    Raises:
        OSError: (e.g. ``FileNotFoundError``) if *directory* cannot be listed.
    """
    suffix = extension.lower()
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.lower().endswith(suffix)
    )

# Collect the bundled example videos of each category.
examples_original, examples_deepfake_roop, examples_deepfake_web = (
    get_video_paths(folder)
    for folder in (original_dir, deepfake_roop_dir, deepfake_web_dir)
)

# Map each example video path to its category name; `predict` uses this to
# recognise example inputs and derive a ground-truth label.
example_videos_dict = {
    path: category
    for paths, category in (
        (examples_original, 'Original'),
        (examples_deepfake_roop, 'DeepfakeRoop'),
        (examples_deepfake_web, 'DeepfakeWeb'),
    )
    for path in paths
}

def process_frame(frame):
    """Turn a BGR video frame into a single-sample batch for ``qcnn_model``.

    The pipeline is: grayscale, resize to 28x28, divide by 255, add a
    trailing channel axis, apply the ``quanv`` quantum convolution, then add
    a leading batch axis.

    Returns:
        Array of shape (1, 14, 14, 2).
    """
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    pixels = cv2.resize(gray, (28, 28)) / 255.0
    features = quanv(pixels[..., np.newaxis])
    return features[np.newaxis, ...]

def _to_cnn_input(frame):
    """Preprocess a BGR frame into a (1, 28, 28, 1) batch for ``cnn_model``.

    The frame is converted to grayscale, resized to 28x28 and divided by 255.
    """
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    small = cv2.resize(gray, (28, 28))
    return (small / 255.0).reshape(1, 28, 28, 1)


def _predict_label(model, batch):
    """Return the 0/1 class a Keras ``model`` assigns to a single-sample ``batch``.

    The prediction is thresholded at 0.5 (strictly greater -> 1). ``.item()``
    requires the prediction to hold exactly one score and raises
    ``ValueError`` otherwise.
    NOTE(review): assumes one sigmoid output per sample — confirm against the
    saved models.
    """
    prediction = model.predict(batch)
    return int((prediction > 0.5).item())


def process_video(video_path, true_label=None, max_frames=30):
    """Classify sampled frames of a video with both models and summarise them.

    Every ``round(fps / 30)``-th frame is sampled (at least every frame), so
    about 30 frames per second of video are used. Sampling stops after
    ``max_frames`` frames have been classified or when the video ends. Each
    sampled frame is scored by ``cnn_model`` (28x28 grayscale input) and by
    ``qcnn_model`` (quanvolution features from ``process_frame``). A score
    above 0.5 counts as "Fake" (class 1); otherwise it counts as "Original"
    (class 0).

    Args:
        video_path: Path to a video file readable by OpenCV.
        true_label: Optional ground-truth class (0 = original, 1 = fake).
            When given, the per-model frame accuracy is prepended.
        max_frames: Maximum number of sampled frames to classify.

    Returns:
        A multi-line, human-readable summary. If no frame could be read, a
        short message saying so is returned instead.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps == 0 or np.isnan(fps):
            fps = 30  # FPS unknown: assume 30, i.e. sample every frame
        frame_interval = max(int(round(fps / 30)), 1)

        cnn_labels = []
        qcnn_labels = []
        frame_count = 0
        while cap.isOpened() and len(cnn_labels) < max_frames:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_count % frame_interval == 0:
                cnn_labels.append(_predict_label(cnn_model, _to_cnn_input(frame)))
                qcnn_labels.append(_predict_label(qcnn_model, process_frame(frame)))
            frame_count += 1
    finally:
        # Release the capture even if decoding or a model prediction raises.
        cap.release()

    total_frames = len(cnn_labels)
    if total_frames == 0:
        # Unreadable or empty video: do not report misleading 0% figures.
        return "Could not read any frames from the video."

    def percent(count):
        return (count / total_frames) * 100

    cnn_fake = sum(cnn_labels)
    qcnn_fake = sum(qcnn_labels)

    lines = []
    if true_label is not None:
        cnn_correct = sum(label == true_label for label in cnn_labels)
        qcnn_correct = sum(label == true_label for label in qcnn_labels)
        lines.append(f"CNN Model Accuracy: {percent(cnn_correct):.2f}%")
        lines.append(f"QCNN Model Accuracy: {percent(qcnn_correct):.2f}%")
    lines += [
        "CNN Model Predictions:",
        f"Original: {percent(total_frames - cnn_fake):.2f}%",
        f"Fake: {percent(cnn_fake):.2f}%",
        "QCNN Model Predictions:",
        f"Original: {percent(total_frames - qcnn_fake):.2f}%",
        f"Fake: {percent(qcnn_fake):.2f}%",
    ]
    # One item per line, including between the accuracy and prediction sections.
    return "\n".join(lines)

def _lookup_true_label(video_path):
    """Return the ground-truth class of *video_path* if it is a bundled example.

    The path is first looked up in ``example_videos_dict``. If that fails,
    the function tries examples with the same file name and uses one only if
    its contents are byte-identical.
    NOTE(review): this covers Gradio handing the function a temporary copy of
    a selected example rather than the original path — confirm for the
    deployed Gradio version.

    Returns:
        0 for an 'Original' example, 1 for any other (deepfake) example, or
        None if the video is not recognised as an example.
    """
    import filecmp  # only needed for the content-comparison fallback

    label = example_videos_dict.get(video_path)
    if label is None:
        file_name = os.path.basename(video_path)
        for example_path, example_label in example_videos_dict.items():
            if os.path.basename(example_path) != file_name:
                continue
            try:
                identical = filecmp.cmp(example_path, video_path, shallow=False)
            except OSError:
                continue
            if identical:
                label = example_label
                break
    if label is None:
        return None
    return 0 if label == 'Original' else 1


def predict(video_input):
    """Gradio callback: classify the uploaded or selected video.

    Args:
        video_input: ``None``, a path (``str`` or ``os.PathLike``), or a dict
            carrying the path under ``'name'``. ``'path'`` is also accepted
            (presumably newer Gradio file payloads).

    Returns:
        The summary text from ``process_video``. If the input is missing or
        unusable, a short message saying so is returned instead.
    """
    if video_input is None:
        return "Please upload a video or select an example."
    if isinstance(video_input, dict):
        video_path = video_input.get('name') or video_input.get('path')
    elif isinstance(video_input, (str, os.PathLike)):
        video_path = os.fspath(video_input)
    else:
        video_path = None
    if not video_path:
        return "Invalid video input."

    # Only example videos have a known label; uploads are scored without one.
    true_label = _lookup_true_label(video_path)
    return process_video(video_path, true_label=true_label)

with gr.Blocks() as demo:
    gr.HTML("<h1 style='text-align: center;'>Quanvolutional Neural Networks for Deepfake Detection</h1>")
    gr.HTML("<h2 style='text-align: center;'>Steven Fernandes, Ph.D.</h2>")

    with gr.Row():
        with gr.Column():
            video_input = gr.Video(label="Upload Video")
            # The gr.Examples return values are not needed. Not assigning them
            # keeps the module-level example path lists (examples_original,
            # examples_deepfake_web, examples_deepfake_roop) intact.
            gr.Examples(
                label="Training: Original Videos",
                inputs=video_input,
                examples=examples_original,
            )
            gr.Examples(
                label="Training: Deepfake Videos Generated Using Deepfakesweb",
                inputs=video_input,
                examples=examples_deepfake_web,
            )
            gr.Examples(
                label="Testing: Deepfake Videos Generated Using Roop",
                inputs=video_input,
                examples=examples_deepfake_roop,
            )
        with gr.Column():
            output = gr.Textbox(label="Result")
            predict_button = gr.Button("Predict")

    predict_button.click(fn=predict, inputs=video_input, outputs=output)

# Launch after the Blocks context has closed, and only when run as a script,
# so importing this module does not start a server.
if __name__ == "__main__":
    demo.launch()