# yaya36095's picture
# Update app.py
# 2520146 verified
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import gradio as gr
import os
import sys
# Startup diagnostics: show where the app is running and which files are
# visible, which helps when a checkpoint is not found.
print("Starting AI Image Detector...")
print(f"Working directory: {os.getcwd()}")
print(f"Files in directory: {os.listdir('.')}")

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Inference preprocessing, meant to mirror the validation transforms used in
# training: resize to 224x224, convert to a tensor, then normalize each of the
# three channels with the standard ImageNet mean / std.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)
def _build_model():
    """Return a fresh EfficientNetV2-S whose head matches the training classifier."""
    net = models.efficientnet_v2_s(weights=None)
    # classifier[1] is the Linear layer of torchvision's stock head; its
    # in_features is the backbone's feature width.
    in_features = net.classifier[1].in_features
    net.classifier = nn.Sequential(
        nn.Linear(in_features, 1024),
        nn.ReLU(),
        nn.Dropout(p=0.3),
        nn.Linear(1024, 512),
        nn.ReLU(),
        nn.Dropout(p=0.3),
        nn.Linear(512, 2),
    )
    return net


def _load_state_dict(path):
    """Read the checkpoint at *path* onto ``device``, refusing arbitrary pickles when supported."""
    import inspect

    kwargs = {"map_location": device}
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint cannot run code at load time. Older torch
    # releases lack the argument, so only pass it when it exists.
    if "weights_only" in inspect.signature(torch.load).parameters:
        kwargs["weights_only"] = True
    return torch.load(path, **kwargs)


def load_model():
    """Build the classifier and load fine-tuned weights from the first usable checkpoint.

    Candidate locations are tried in order: the working directory, then
    ``/repository``, then the directory containing this script. The first
    existing file that loads without error wins.

    Returns:
        The model, moved to ``device`` and switched to eval mode. If no
        checkpoint could be loaded, a WARNING is printed and the model keeps
        freshly initialised (untrained) weights.
    """
    print("Creating model architecture...")
    model = _build_model()

    script_dir = os.path.dirname(os.path.abspath(__file__))
    possible_paths = [
        "best_model_improved.pth",
        "pytorch_model.bin",
        "/repository/best_model_improved.pth",
        "/repository/pytorch_model.bin",
        os.path.join(script_dir, "best_model_improved.pth"),
        os.path.join(script_dir, "pytorch_model.bin"),
    ]

    model_loaded = False
    for model_path in possible_paths:
        if not os.path.exists(model_path):
            continue
        print(f"Loading model from: {model_path}")
        try:
            model.load_state_dict(_load_state_dict(model_path))
        except Exception as e:
            print(f"Error loading from {model_path}: {e}")
            # load_state_dict copies every matching tensor before it raises on
            # a mismatch, so discard the partially loaded model.
            model = _build_model()
            continue
        model_loaded = True
        break

    if not model_loaded:
        print("WARNING: Could not load model weights. Using untrained model.")
    model.to(device)
    model.eval()
    return model
# Lazily populated classifier cache: predict_image() loads the model on the
# first request and reuses it afterwards.
model: "nn.Module | None" = None
def predict_image(img):
    """Classify *img* as a real photograph or an AI-generated image.

    Args:
        img: The uploaded image (a PIL image when wired to
            ``gr.Image(type="pil")``), or ``None`` when nothing was uploaded.

    Returns:
        A ``(label_value, summary)`` pair for the two output components:

        * ``label_value``: on success, a dict mapping each class name to its
          softmax probability in ``[0, 1]``, which is the confidence format
          ``gr.Label`` expects. On failure, a plain string that ``gr.Label``
          displays as the label.
        * ``summary``: text such as ``"AI-Generated Image - Confidence: 97.12%"``,
          or an ``"Error: ..."`` message.

    Errors during model loading or inference are caught, printed with a
    traceback, and reported through the return value instead of raised.
    """
    global model
    if img is None:
        return "No image provided", "Error: Please upload an image"
    try:
        # Load the model lazily on the first request.
        if model is None:
            model = load_model()

        # ToTensor keeps the source channel count, but Normalize is set up for
        # 3 channels, so RGBA / greyscale / palette images must become RGB.
        if isinstance(img, Image.Image) and img.mode != "RGB":
            img = img.convert("RGB")
        img_tensor = transform(img).unsqueeze(0).to(device)

        with torch.no_grad():
            outputs = model(img_tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]
            prediction = torch.argmax(probabilities).item()

        # Index 0 is the "real" class and index 1 the "AI-generated" class.
        real_prob = probabilities[0].item()
        ai_prob = probabilities[1].item()

        # gr.Label needs numeric confidences, not preformatted strings.
        result = {"Real Image": real_prob, "AI-Generated": ai_prob}

        if prediction == 0:
            classification, confidence = "Real Image", real_prob
        else:
            classification, confidence = "AI-Generated Image", ai_prob
        confidence_text = f"Confidence: {confidence * 100:.2f}%"
        return result, classification + " - " + confidence_text
    except Exception as e:
        import traceback

        print(f"Error during prediction: {e}")
        traceback.print_exc()
        return "Prediction failed", f"Error: {str(e)}"
# Define Gradio interface - simplified for Hugging Face
def create_interface():
    """Build and return the Gradio Blocks UI for the detector.

    The layout is an upload column (image input plus an "Analyze Image"
    button) next to a results column (probability label plus a classification
    textbox), followed by a short "How It Works" section. Clicking the button
    runs ``predict_image`` on the uploaded image and fills both result
    components.
    """
    subtitle = "Upload an image to check if it's real or AI-generated"
    how_it_works = (
        "\n"
        "This tool uses a deep learning model trained on thousands of real and AI-generated images.\n"
        "The model analyzes visual patterns that are typically present in AI-generated images but not in real photographs.\n"
        "**Note**: While the model is highly accurate, it's not perfect. Some AI-generated images may be classified as real, and vice versa.\n"
    )

    with gr.Blocks(title="AI Image Detector", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# AI Image Detector")
        gr.Markdown(subtitle)

        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Upload Image")
                analyze_button = gr.Button("Analyze Image", variant="primary")
            with gr.Column():
                probabilities_label = gr.Label(label="Prediction Probabilities")
                verdict_box = gr.Textbox(label="Classification Result")

        # Wire the button to the classifier; outputs fill in order.
        analyze_button.click(
            fn=predict_image,
            inputs=image_input,
            outputs=[probabilities_label, verdict_box],
        )

        gr.Markdown("### How It Works")
        gr.Markdown(how_it_works)

    return demo
# Launch the interface only when this file is executed directly, not on import.
if __name__ == "__main__":
    create_interface().launch()