File size: 5,009 Bytes
88d71c2 51a2a1d 88d71c2 2520146 88d71c2 2520146 51a2a1d 2520146 51a2a1d 2520146 51a2a1d 2520146 eda917f 51a2a1d eda917f 51a2a1d 2520146 eda917f 51a2a1d eda917f 88d71c2 51a2a1d eda917f 88d71c2 51a2a1d eda917f 51a2a1d eda917f 51a2a1d eda917f 51a2a1d 2520146 eda917f 2520146 51a2a1d eda917f 2520146 51a2a1d 2520146 eda917f 2520146 51a2a1d 88d71c2 51a2a1d 2520146 eda917f 51a2a1d 2520146 51a2a1d eda917f 51a2a1d eda917f 2520146 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import gradio as gr
import os
import sys
# Startup diagnostics: report where the app is running and which files it can
# see (presumably to debug the checkpoint lookup in load_model()).
print("Starting AI Image Detector...")
print(f"Working directory: {os.getcwd()}")
print(f"Files in directory: {os.listdir('.')}")

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Preprocessing pipeline; must stay identical to the validation transforms used
# during training. The normalisation constants are the standard ImageNet
# per-channel mean / std (one value per R, G, B channel).
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_NORMALIZE_MEAN, _NORMALIZE_STD),
]
transform = transforms.Compose(_preprocess_steps)
def _build_model():
    """Return a randomly initialised EfficientNetV2-S with the custom 2-class head.

    The head mirrors the classifier used in training:
    Linear(in, 1024) -> ReLU -> Dropout(0.3) -> Linear(1024, 512) -> ReLU
    -> Dropout(0.3) -> Linear(512, 2).
    """
    net = models.efficientnet_v2_s(weights=None)
    # Index 1 of the stock classifier is its Linear layer; reuse its input width.
    in_features = net.classifier[1].in_features
    net.classifier = nn.Sequential(
        nn.Linear(in_features, 1024),
        nn.ReLU(),
        nn.Dropout(p=0.3),
        nn.Linear(1024, 512),
        nn.ReLU(),
        nn.Dropout(p=0.3),
        nn.Linear(512, 2),
    )
    return net


def _candidate_weight_paths():
    """Return the checkpoint locations to try, in priority order."""
    script_dir = os.path.dirname(os.path.abspath(__file__))
    return [
        "best_model_improved.pth",
        "pytorch_model.bin",
        "/repository/best_model_improved.pth",
        "/repository/pytorch_model.bin",
        os.path.join(script_dir, "best_model_improved.pth"),
        os.path.join(script_dir, "pytorch_model.bin"),
    ]


def _extract_state_dict(checkpoint):
    """Return the parameter mapping stored in *checkpoint*.

    Accepts either a bare state dict or a training checkpoint dict that nests
    it under "model_state_dict" or "state_dict". Both are common torch.save
    conventions; which one the training script used is not visible here.
    A plain state dict never has those keys, so it is returned unchanged.
    """
    if isinstance(checkpoint, dict):
        for key in ("model_state_dict", "state_dict"):
            nested = checkpoint.get(key)
            if isinstance(nested, dict):
                return nested
    return checkpoint


def load_model():
    """Build the classifier and load weights from the first usable checkpoint.

    Each existing path from _candidate_weight_paths() is tried in order until
    one loads. If none does, a warning is printed and the randomly initialised
    model is returned; its predictions are then meaningless.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    print("Creating model architecture...")
    model = _build_model()
    model_loaded = False
    for model_path in _candidate_weight_paths():
        if not os.path.exists(model_path):
            continue
        print(f"Loading model from: {model_path}")
        try:
            # NOTE(review): torch.load unpickles the file; only point this at
            # trusted checkpoints shipped with the app.
            checkpoint = torch.load(model_path, map_location=device)
            model.load_state_dict(_extract_state_dict(checkpoint))
        except Exception as e:
            print(f"Error loading from {model_path}: {e}")
            # load_state_dict copies every matching tensor before it raises on
            # missing/unexpected keys or shape mismatches, so discard the
            # possibly half-loaded model before trying the next candidate.
            model = _build_model()
            continue
        model_loaded = True
        break
    if not model_loaded:
        print("WARNING: Could not load model weights. Using untrained model.")
    model.to(device)
    model.eval()
    return model
# Global model variable: built lazily by predict_image() on the first request
# rather than at import time.
model = None


def predict_image(img):
    """Classify *img* as a real photograph or an AI-generated image.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or None when
            nothing has been uploaded.

    Returns:
        A ``(confidences, summary)`` pair feeding the Label and Textbox outputs.

        - ``confidences`` maps "Real Image" / "AI-Generated" to probabilities
          in [0, 1]. gr.Label needs numeric values and formats them itself.
        - ``summary`` is the verdict text with a percentage confidence.
        - On any error, ``confidences`` is None (clearing the Label) and
          ``summary`` carries the error message.
    """
    global model
    if img is None:
        return None, "Error: Please upload an image"
    try:
        # Load model if not already loaded
        if model is None:
            model = load_model()
        # The Normalize step is configured for exactly 3 channels; RGBA,
        # greyscale ("L") or palette ("P") images would fail there.
        if isinstance(img, Image.Image) and img.mode != "RGB":
            img = img.convert("RGB")
        img_tensor = transform(img).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model(img_tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]
        # Class index 0 is "real" and index 1 is "AI-generated" throughout this app.
        prediction = torch.argmax(probabilities).item()
        real_prob = probabilities[0].item()
        ai_prob = probabilities[1].item()
        # gr.Label expects {label: float}; pre-formatted "xx.xx%" strings are
        # rejected by its validation (Gradio 4+) or sorted as text (3.x).
        result = {"Real Image": real_prob, "AI-Generated": ai_prob}
        if prediction == 0:
            classification, confidence = "Real Image", real_prob
        else:
            classification, confidence = "AI-Generated Image", ai_prob
        return result, f"{classification} - Confidence: {confidence * 100:.2f}%"
    except Exception as e:
        # UI boundary: log the full traceback server-side and show the message
        # in the textbox instead of crashing the request.
        import traceback
        print(f"Error during prediction: {e}")
        traceback.print_exc()
        return None, f"Error: {e}"
# Define Gradio interface - simplified for Hugging Face
def create_interface():
    """Assemble the Gradio Blocks UI.

    Layout:
        - A header, then a row with the image upload and "Analyze Image"
          button on the left.
        - The probability label and verdict textbox on the right.
        - An explanatory "How It Works" section below the row.

    Pressing the button runs predict_image on the uploaded image.

    Returns:
        The constructed, not yet launched, ``gr.Blocks`` app.
    """
    explainer_md = (
        "\n"
        "This tool uses a deep learning model trained on thousands of real and AI-generated images.\n"
        "The model analyzes visual patterns that are typically present in AI-generated images but not in real photographs.\n"
        "**Note**: While the model is highly accurate, it's not perfect. Some AI-generated images may be classified as real, and vice versa.\n"
    )

    with gr.Blocks(title="AI Image Detector", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# AI Image Detector")
        gr.Markdown("Upload an image to check if it's real or AI-generated")

        with gr.Row():
            with gr.Column():
                image_in = gr.Image(type="pil", label="Upload Image")
                run_button = gr.Button("Analyze Image", variant="primary")
            with gr.Column():
                probs_out = gr.Label(label="Prediction Probabilities")
                verdict_out = gr.Textbox(label="Classification Result")

        # Inference runs only on an explicit button press, not on upload.
        run_button.click(fn=predict_image, inputs=image_in, outputs=[probs_out, verdict_out])

        gr.Markdown("### How It Works")
        gr.Markdown(explainer_md)

    return demo
# Launch the interface.
# Only start the Gradio server when this file is executed directly, not when
# it is imported as a module.
if __name__ == "__main__":
    interface = create_interface()
    interface.launch()
|