|
import gradio as gr |
|
from PIL import Image |
|
import torch |
|
import torch.nn as nn |
|
import torchvision.transforms as transforms |
|
import torchvision.models as models |
|
import os |
|
import torch |
|
|
|
|
|
# Run inference on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Human-readable labels, looked up by the index of the winning model output.
classes_name = ['AI-generated Image', 'Real Image']

# Build MobileNetV3-Large without pretrained weights; every parameter is
# supplied by the checkpoint loaded below.
main_model = models.mobilenet_v3_large(weights=None)

# Replace the classifier layer at index 3 with a two-output linear head so
# the architecture matches the two entries in ``classes_name``.
num_ftrs = main_model.classifier[3].in_features
main_model.classifier[3] = nn.Linear(num_ftrs, 2)

# ``weights_only=True`` limits unpickling to tensors and plain containers;
# ``map_location`` places the tensors on the selected device while loading.
main_model.load_state_dict(
    torch.load(
        'best_model3_mobilenetv3_large.pth',
        map_location=device,
        weights_only=True,
    )
)

# Move the model to the target device and switch it to evaluation mode
# (both calls return the module itself, so they can be chained).
main_model = main_model.to(device).eval()
|
|
|
def convert_to_rgb(image):
    """
    Return ``image`` as a three-channel 'RGB' image.

    Images already in 'RGB' mode are returned unchanged. Every other mode
    (palette 'P', 'RGBA', grayscale 'L', 'LA', 'CMYK', ...) is converted with
    ``image.convert('RGB')``, which discards any alpha channel. This keeps
    the input compatible with the three-channel mean/std used by the
    normalization step of the preprocessing pipeline in this file.

    Args:
        image: A PIL image (any object exposing ``mode`` and ``convert``).

    Returns:
        The same object when it is already 'RGB', otherwise the result of
        ``image.convert('RGB')``.
    """
    if image.mode == 'RGB':
        return image
    # Previously only 'P' and 'RGBA' were converted, so grayscale or CMYK
    # inputs reached normalization with the wrong number of channels.
    return image.convert('RGB')
|
|
|
|
|
# Preprocessing applied to every image before it is fed to the model:
#   1. force three-channel RGB,
#   2. resize to 224x224 (aspect ratio is not preserved),
#   3. convert to a float tensor,
#   4. normalize each channel with the commonly used ImageNet mean/std.
preprocess = transforms.Compose([
    transforms.Lambda(convert_to_rgb),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
|
|
|
def classify_image(image):
    """
    Classify an uploaded image as AI-generated or real.

    Args:
        image: Image data accepted by ``Image.fromarray`` (an array of
            pixel values), or ``None`` when no image was supplied.

    Returns:
        A sentence naming the predicted class from ``classes_name`` and the
        softmax confidence formatted to four decimals, or a short prompt
        asking for an image when ``image`` is ``None``.
    """
    # Submitting the form without an image yields None; answer with a
    # message instead of letting Image.fromarray raise.
    if image is None:
        return "No image provided. Please upload an image."

    pil_image = Image.fromarray(image)

    # Add a leading batch dimension and move the tensor to the model's device.
    input_image = preprocess(pil_image).unsqueeze(0).to(device)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        output = main_model(input_image)
        # output[0] holds the scores of the single image in the batch.
        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        confidence, predicted_class = torch.max(probabilities, 0)

    # Convert the 0-dim tensors to plain Python numbers before using them.
    main_prediction = classes_name[predicted_class.item()]
    main_confidence = confidence.item()

    return f"Image is : {main_prediction} (Confidence: {main_confidence:.4f})"
|
|
|
|
|
# Input widget: the uploaded image is delivered to the callback in RGB mode.
image_input = gr.Image(image_mode="RGB")

# Output widget: a single text box showing the prediction sentence.
output_text = gr.Textbox()

# Assemble the web UI around ``classify_image``. The description is Markdown
# listing the datasets the model was trained on.
demo = gr.Interface(
    fn=classify_image,
    inputs=image_input,
    outputs=[output_text],
    title="Detect AI-generated Image",
    description=(
        "Upload an art image From 6 websites, collecting data from this to detect if it's AI-generated or a real image. take care image jpg or png only.\n\n"
        "### Main Dataset Used:\n"
        "- [AI-generated Images vs Real Images (Kaggle)](https://www.kaggle.com/datasets/tristanzhang32/ai-generated-images-vs-real-images)\n\n"
        "**Fake Images Collected From:**\n"
        "- 10,000 from [Stable Diffusion (OpenArt AI)](https://www.openart.ai)\n"
        "- 10,000 from [MidJourney (Imagine.Art)](https://www.imagine.art)\n"
        "- 10,000 from [DALL·E (OpenAI)](https://openai.com/dall-e-2)\n\n"
        "**Real Images Collected From:**\n"
        "- 7,500 from [WikiArt](https://www.wikiart.org)\n"
        "- 22,500 from [Pexels](https://www.pexels.com) and [Unsplash but take care image jpg or png only ](https://unsplash.com)\n"
    ),
    theme="default",
)

# Start the Gradio server.
demo.launch()