Update app.py
Browse files
app.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
| 1 |
import spaces
|
| 2 |
import gradio as gr
|
| 3 |
-
from transformers import pipeline, AutoImageProcessor, Swinv2ForImageClassification
|
| 4 |
from torchvision import transforms
|
| 5 |
import torch
|
| 6 |
from PIL import Image
|
|
|
|
| 7 |
import warnings
|
|
|
|
| 8 |
|
| 9 |
# Suppress warnings
|
| 10 |
warnings.filterwarnings("ignore", category=UserWarning, message="Using a slow image processor as `use_fast` is unset")
|
|
@@ -20,13 +22,24 @@ clf_1 = pipeline(model=model_1, task="image-classification", image_processor=ima
|
|
| 20 |
|
| 21 |
# Load the second model
|
| 22 |
model_2_path = "Heem2/AI-vs-Real-Image-Detection"
|
| 23 |
-
clf_2 = pipeline("image-classification", model=model_2_path
|
| 24 |
|
| 25 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
class_names_1 = ['artificial', 'real']
|
| 27 |
-
class_names_2 = ['AI Image', 'Real Image']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
-
@spaces.GPU(duration=
|
| 30 |
def predict_image(img, confidence_threshold):
|
| 31 |
# Ensure the image is a PIL Image
|
| 32 |
if not isinstance(img, Image.Image):
|
|
@@ -81,10 +94,56 @@ def predict_image(img, confidence_threshold):
|
|
| 81 |
except Exception as e:
|
| 82 |
label_2 = f"Error: {str(e)}"
|
| 83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
# Combine results
|
| 85 |
combined_results = {
|
| 86 |
"SwinV2": label_1,
|
| 87 |
-
"AI-vs-Real-Image-Detection": label_2
|
|
|
|
|
|
|
| 88 |
}
|
| 89 |
|
| 90 |
return combined_results
|
|
|
|
| 1 |
import spaces
|
| 2 |
import gradio as gr
|
| 3 |
+
from transformers import pipeline, AutoImageProcessor, Swinv2ForImageClassification, AutoFeatureExtractor, AutoModelForImageClassification
|
| 4 |
from torchvision import transforms
|
| 5 |
import torch
|
| 6 |
from PIL import Image
|
| 7 |
+
import pandas as pd
|
| 8 |
import warnings
|
| 9 |
+
import math
|
| 10 |
|
| 11 |
# Suppress warnings
|
| 12 |
warnings.filterwarnings("ignore", category=UserWarning, message="Using a slow image processor as `use_fast` is unset")
|
|
|
|
| 22 |
|
| 23 |
# Load the second model
|
| 24 |
model_2_path = "Heem2/AI-vs-Real-Image-Detection"
|
| 25 |
+
clf_2 = pipeline("image-classification", model=model_2_path)
|
| 26 |
|
| 27 |
+
# Load additional models
|
| 28 |
+
models = ["Organika/sdxl-detector", "cmckinle/sdxl-flux-detector"]
|
| 29 |
+
pipe0 = pipeline("image-classification", model=models[0])
|
| 30 |
+
pipe1 = pipeline("image-classification", model=models[1])
|
| 31 |
+
|
| 32 |
+
# Define class names for all models
|
| 33 |
class_names_1 = ['artificial', 'real']
|
| 34 |
+
class_names_2 = ['AI Image', 'Real Image']
|
| 35 |
+
class_names_3 = ['AI', 'Real']
|
| 36 |
+
class_names_4 = ['AI', 'Real']
|
| 37 |
+
|
| 38 |
+
def softmax(vector):
|
| 39 |
+
e = math.exp(vector - vector.max()) # for numerical stability
|
| 40 |
+
return e / e.sum()
|
| 41 |
|
| 42 |
+
@spaces.GPU(duration=10)
|
| 43 |
def predict_image(img, confidence_threshold):
|
| 44 |
# Ensure the image is a PIL Image
|
| 45 |
if not isinstance(img, Image.Image):
|
|
|
|
| 94 |
except Exception as e:
|
| 95 |
label_2 = f"Error: {str(e)}"
|
| 96 |
|
| 97 |
+
# Predict using the third model
|
| 98 |
+
try:
|
| 99 |
+
prediction_3 = pipe0(img_pil)
|
| 100 |
+
result_3 = {}
|
| 101 |
+
for idx, result in enumerate(prediction_3):
|
| 102 |
+
result_3[class_names_3[idx]] = float(result['score'])
|
| 103 |
+
|
| 104 |
+
# Ensure the result dictionary contains all class names
|
| 105 |
+
for class_name in class_names_3:
|
| 106 |
+
if class_name not in result_3:
|
| 107 |
+
result_3[class_name] = 0.0
|
| 108 |
+
|
| 109 |
+
# Check if either class meets the confidence threshold
|
| 110 |
+
if result_3['AI'] >= confidence_threshold:
|
| 111 |
+
label_3 = f"Label: AI, Confidence: {result_3['AI']:.4f}"
|
| 112 |
+
elif result_3['Real'] >= confidence_threshold:
|
| 113 |
+
label_3 = f"Label: Real, Confidence: {result_3['Real']:.4f}"
|
| 114 |
+
else:
|
| 115 |
+
label_3 = "Uncertain Classification"
|
| 116 |
+
except Exception as e:
|
| 117 |
+
label_3 = f"Error: {str(e)}"
|
| 118 |
+
|
| 119 |
+
# Predict using the fourth model
|
| 120 |
+
try:
|
| 121 |
+
prediction_4 = pipe1(img_pil)
|
| 122 |
+
result_4 = {}
|
| 123 |
+
for idx, result in enumerate(prediction_4):
|
| 124 |
+
result_4[class_names_4[idx]] = float(result['score'])
|
| 125 |
+
|
| 126 |
+
# Ensure the result dictionary contains all class names
|
| 127 |
+
for class_name in class_names_4:
|
| 128 |
+
if class_name not in result_4:
|
| 129 |
+
result_4[class_name] = 0.0
|
| 130 |
+
|
| 131 |
+
# Check if either class meets the confidence threshold
|
| 132 |
+
if result_4['AI'] >= confidence_threshold:
|
| 133 |
+
label_4 = f"Label: AI, Confidence: {result_4['AI']:.4f}"
|
| 134 |
+
elif result_4['Real'] >= confidence_threshold:
|
| 135 |
+
label_4 = f"Label: Real, Confidence: {result_4['Real']:.4f}"
|
| 136 |
+
else:
|
| 137 |
+
label_4 = "Uncertain Classification"
|
| 138 |
+
except Exception as e:
|
| 139 |
+
label_4 = f"Error: {str(e)}"
|
| 140 |
+
|
| 141 |
# Combine results
|
| 142 |
combined_results = {
|
| 143 |
"SwinV2": label_1,
|
| 144 |
+
"AI-vs-Real-Image-Detection": label_2,
|
| 145 |
+
"Organika/sdxl-detector": label_3,
|
| 146 |
+
"cmckinle/sdxl-flux-detector": label_4
|
| 147 |
}
|
| 148 |
|
| 149 |
return combined_results
|