"""
Hugging Face Spaces inference API for Enhanced AI Image Detector
"""
import os

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

# Define the model architecture based on EfficientNetV2-S
class AIDetectorModel(nn.Module):
    def __init__(self):
        super(AIDetectorModel, self).__init__()
        # Load EfficientNetV2-S as the base model
        self.base_model = models.efficientnet_v2_s(weights=None)
        # Replace the classifier with custom layers
        self.base_model.classifier = nn.Sequential(
            nn.Linear(self.base_model.classifier[1].in_features, 1024),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(512, 2)  # 2 classes: real or AI-generated
        )

    def forward(self, x):
        return self.base_model(x)
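
# Optional sanity check (an illustrative sketch, not part of the inference
# path): a freshly constructed model should map a 224x224 RGB batch to two
# logits per image. AI_DETECTOR_SELFTEST is a hypothetical env var; set it
# locally to run this check.
if os.environ.get("AI_DETECTOR_SELFTEST"):
    _m = AIDetectorModel()
    _x = torch.randn(1, 3, 224, 224)
    assert _m(_x).shape == (1, 2), "expected logits of shape (batch, 2)"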

# Custom detector class
class AIDetector:
    def __init__(self, model_path='best_model_improved.pth'):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        # Initialize the model
        self.model = AIDetectorModel()
        # Load the trained weights
        try:
            # Try to load with strict=True first
            self.model.load_state_dict(torch.load(model_path, map_location=self.device))
            print(f"Model loaded successfully from {model_path}")
        except Exception as e:
            print(f"Error with strict loading: {e}")
            print("Trying with strict=False...")
            # If that fails, try with strict=False
            self.model.load_state_dict(torch.load(model_path, map_location=self.device), strict=False)
            print("Model loaded with strict=False")
        self.model.to(self.device)
        self.model.eval()  # Set to evaluation mode
        # Define image transformations - same as used in training
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])


# Initialize the detector once at startup
detector = AIDetector()
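
# Illustrative direct use of the detector outside Gradio (a sketch; the file
# name is a placeholder):
#
#   img = Image.open("photo.jpg").convert("RGB")
#   x = detector.transform(img).unsqueeze(0).to(detector.device)
#   with torch.no_grad():
#       probs = torch.softmax(detector.model(x), dim=1)[0]
#   print(f"real={probs[0]:.3f}  ai={probs[1]:.3f}")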

def analyze_image(image):
    """
    Analyze an image with the Enhanced AI Image Detector (PyTorch model)

    Args:
        image: Image uploaded through the Gradio interface

    Returns:
        A short result message and a detailed Markdown analysis
    """
    if image is None:
        return "Please upload an image", ""
    try:
        # Convert to a PIL Image if needed
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert('RGB')
        # Preprocess the image
        image_tensor = detector.transform(image).unsqueeze(0).to(detector.device)
        # Make the prediction
        with torch.no_grad():
            outputs = detector.model(image_tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)
        # Class 1 is assumed to be AI-generated, class 0 to be real
        ai_score = probabilities[0, 1].item()
        real_score = probabilities[0, 0].item()
        # Classify as AI-generated above a 0.5 threshold
        is_ai_generated = ai_score > 0.5
        # Create the result message with an emoji and the confidence
        if is_ai_generated:
            message = f"🤖 This image is likely AI-generated (Confidence: {ai_score:.2f})"
        else:
            message = f"📷 This image is likely authentic (Confidence: {real_score:.2f})"
        # Create the detailed analysis with Markdown formatting
        detailed_analysis = f"""
### Detailed Analysis

| Property | Value |
|----------|-------|
| AI Score | {ai_score:.4f} |
| Real Score | {real_score:.4f} |
| Prediction | {'AI-generated' if is_ai_generated else 'Real'} |
| Model | Enhanced AI Image Detector |
| Architecture | EfficientNetV2-S |
"""
        return message, detailed_analysis
    except Exception as e:
        return f"Error analyzing image: {str(e)}", ""

# Create the Gradio interface with an improved UI
with gr.Blocks(title="Enhanced AI Image Detector", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🔍 Enhanced AI Image Detector")
    gr.Markdown("Upload an image to determine whether it is real or AI-generated.")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload Image", elem_id="input-image")
            analyze_button = gr.Button("🔍 Analyze Image", variant="primary")
            # Add example images for quick testing
            gr.Examples(
                examples=[
                    "https://storage.googleapis.com/kagglesdsdata/datasets/3676638/6262341/real_and_fake_face/training_real/real_00001.jpg",
                    "https://storage.googleapis.com/kagglesdsdata/datasets/3676638/6262341/real_and_fake_face/training_fake/fake_00001.jpg"
                ],
                inputs=input_image,
                label="Example Images"
            )
        with gr.Column(scale=1):
            result_text = gr.Textbox(label="Result", elem_id="result-text")
            detailed_output = gr.Markdown(label="Detailed Analysis", elem_id="detailed-output")
    # Wire the button to the analysis function; api_name exposes it as a named API endpoint
    analyze_button.click(
        fn=analyze_image,
        inputs=[input_image],
        outputs=[result_text, detailed_output],
        api_name="analyze"
    )
    gr.Markdown("""
## How it works
This app uses a trained PyTorch neural network (EfficientNetV2-S) to detect AI-generated images. The model was trained on a large dataset of real and AI-generated images to learn the subtle differences between them.

The model can detect patterns that are often invisible to the human eye, including:
1. **Noise and artifact patterns** specific to AI generation methods
2. **Texture inconsistencies** that appear in AI-generated content
3. **Color and lighting anomalies** common in synthetic images
4. **Structural patterns** that differ from natural photographs

## Limitations
- The model may struggle with highly realistic images from newer generative models
- Some real images with unusual characteristics may be misclassified
- Performance depends on image quality and resolution
- The model works best with images similar to those in its training dataset
""")

# Launch the app with an improved configuration
if __name__ == "__main__":
    demo.queue()  # Enable the request queue (the enable_queue launch flag is deprecated)
    demo.launch(
        share=True,      # Create a public link for sharing
        show_error=True  # Show detailed error messages
    )
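
# Hedged client-side sketch: since the click handler sets api_name="analyze",
# the endpoint can also be called remotely with gradio_client. The Space ID is
# a placeholder, and the exact call signature varies across client versions:
#
#   from gradio_client import Client, handle_file
#   client = Client("username/enhanced-ai-image-detector")
#   message, details = client.predict(handle_file("local_image.jpg"), api_name="/analyze")
#   print(message)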