|
import torch |
|
import torch.nn as nn |
|
from torchvision import models, transforms |
|
from PIL import Image |
|
import gradio as gr |
|
import os |
|
import sys |
|
|
|
# Startup diagnostics: report where the process is running and what files
# sit next to it, so the hosting logs show the runtime environment.
print("Starting AI Image Detector...")
print(f"Working directory: {os.getcwd()}")
print(f"Files in directory: {os.listdir('.')}")

# Prefer CUDA when PyTorch can see a GPU; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Preprocessing applied to every image before inference: resize to 224x224,
# convert to a float tensor, then normalize each of the three channels.
# NOTE(review): the mean/std values match the ImageNet statistics used by
# torchvision's pretrained models -- presumably the same preprocessing the
# checkpoint was trained with; confirm against the training code.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
|
def load_model():
    """Build the classifier network and load trained weights when available.

    An EfficientNetV2-S backbone is created without pretrained weights and
    its classification head is replaced by a 1024 -> 512 -> 2 MLP. A fixed
    list of checkpoint locations is then probed in order; the first file
    that exists and loads successfully is used. If none loads, a warning is
    printed and the randomly initialised network is returned.

    Returns:
        The network, moved to ``device`` and switched to eval mode.
    """
    print("Creating model architecture...")

    net = models.efficientnet_v2_s(weights=None)

    # Swap the stock head for a small MLP that ends in two logits. The input
    # width is read from the original head before it is replaced.
    head_inputs = net.classifier[1].in_features
    net.classifier = nn.Sequential(
        nn.Linear(head_inputs, 1024),
        nn.ReLU(),
        nn.Dropout(p=0.3),
        nn.Linear(1024, 512),
        nn.ReLU(),
        nn.Dropout(p=0.3),
        nn.Linear(512, 2),
    )

    # Checkpoint locations, probed in this order: current working directory,
    # the /repository mount, then the directory containing this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    candidates = [
        "best_model_improved.pth",
        "pytorch_model.bin",
        "/repository/best_model_improved.pth",
        "/repository/pytorch_model.bin",
        os.path.join(script_dir, "best_model_improved.pth"),
        os.path.join(script_dir, "pytorch_model.bin"),
    ]

    for candidate in candidates:
        if not os.path.exists(candidate):
            continue

        print(f"Loading model from: {candidate}")
        try:
            # NOTE(review): torch.load unpickles the file, so only
            # checkpoints from a trusted source should be placed here.
            state = torch.load(candidate, map_location=device)
            net.load_state_dict(state)
        except Exception as exc:
            # A bad checkpoint is not fatal; fall through to the next path.
            print(f"Error loading from {candidate}: {exc}")
        else:
            break
    else:
        # Reached only when no candidate loaded successfully.
        print("WARNING: Could not load model weights. Using untrained model.")

    net.to(device)
    net.eval()
    return net
|
|
|
|
|
# Classifier instance shared by all requests. It starts out empty and is
# created by the first predict_image() call that receives an image.
model = None


def predict_image(img):
    """Classify an image as real or AI-generated.

    Args:
        img: The uploaded image (a PIL image when called from the Gradio
            ``Image`` input), or ``None`` when nothing was uploaded.

    Returns:
        A ``(confidences, summary)`` tuple.

        ``confidences`` maps class labels to floats in ``[0, 1]``, which is
        the form ``gr.Label`` expects. On success it holds the
        ``"Real Image"`` and ``"AI-Generated"`` probabilities; on failure it
        holds a single error key with confidence ``0.0``.

        ``summary`` is a one-line human-readable verdict, or an
        ``"Error: ..."`` message on failure.

    No exception escapes: any failure during loading or inference is logged
    with its traceback and reported through the return value.
    """
    global model

    if img is None:
        # Values must be numeric for gr.Label; the readable message travels
        # in the second return value.
        return {"Error": 0.0}, "Error: Please upload an image"

    try:
        if model is None:
            model = load_model()

        # Normalize() is configured for exactly three channels, so
        # grayscale, palette and RGBA images must be converted first.
        if isinstance(img, Image.Image):
            img = img.convert("RGB")

        img_tensor = transform(img).unsqueeze(0).to(device)

        with torch.no_grad():
            outputs = model(img_tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]
            prediction = torch.argmax(probabilities).item()

        # Index 0 is the "real" class and index 1 the "AI-generated" class.
        real_prob = probabilities[0].item()
        ai_prob = probabilities[1].item()

        # Raw probabilities (not formatted strings) so gr.Label can rank the
        # classes and draw its confidence bars.
        result = {
            "Real Image": real_prob,
            "AI-Generated": ai_prob,
        }

        if prediction == 0:
            classification = "Real Image"
            confidence = real_prob * 100
        else:
            classification = "AI-Generated Image"
            confidence = ai_prob * 100

        return result, f"{classification} - Confidence: {confidence:.2f}%"

    except Exception as e:
        import traceback

        print(f"Error during prediction: {e}")
        traceback.print_exc()
        return {"error": 0.0}, f"Error: {e}"
|
|
|
|
|
def create_interface():
    """Assemble the Gradio Blocks UI and return it without launching.

    The page has a two-column row: the left column holds the image upload
    and the "Analyze Image" button, the right column holds the probability
    label and the classification text box. Clicking the button sends the
    uploaded image to ``predict_image`` and routes its two return values to
    the two output components. An explanatory section follows the row.

    Returns:
        The configured ``gr.Blocks`` instance.
    """
    with gr.Blocks(title="AI Image Detector", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# AI Image Detector")
        gr.Markdown("Upload an image to check if it's real or AI-generated")

        with gr.Row():
            # Left column: input controls.
            with gr.Column():
                uploaded = gr.Image(type="pil", label="Upload Image")
                run_button = gr.Button("Analyze Image", variant="primary")

            # Right column: prediction outputs.
            with gr.Column():
                probabilities_view = gr.Label(label="Prediction Probabilities")
                verdict_box = gr.Textbox(label="Classification Result")

        # Output order must match the order of predict_image's return values.
        prediction_outputs = [probabilities_view, verdict_box]
        run_button.click(
            fn=predict_image,
            inputs=uploaded,
            outputs=prediction_outputs,
        )

        gr.Markdown("### How It Works")
        gr.Markdown("""
        This tool uses a deep learning model trained on thousands of real and AI-generated images.
        The model analyzes visual patterns that are typically present in AI-generated images but not in real photographs.

        **Note**: While the model is highly accurate, it's not perfect. Some AI-generated images may be classified as real, and vice versa.
        """)

    return demo
|
|
|
|
|
# Script entry point: build the UI and start the Gradio server. Nothing is
# launched when this module is merely imported.
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()
|
|