LPX55 committed
Commit · 2484926
1 Parent(s): 4378fd8

Enhance prediction output and result display in app.py
- Modify prediction methods to generate structured output lists for each model (see the sketch after the file list)
- Add model-specific output tracking with confidence scores and classification labels
- Update HTML results display to include model badges
- Adjust Gradio interface layout for better visualization
- Improve error handling and logging in prediction functions
- .gitignore +1 -0
- app.py +42 -18
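
For orientation, each prediction branch now builds a per-model output row alongside its display label. A minimal sketch of the row convention used in the new app.py; the score values below are placeholders, not real model outputs:

# Row layout per model: [model_index, real_score, ai_score, classification_label]
# (scores here are placeholder values for illustration only)
result_1output = [1, 0.97, 0.03, 'REAL']        # SwinV2/detect
result_2output = [2, 0.12, 0.88, 'AI']          # ViT/AI-vs-Real
result_3output = [3, 0.49, 0.51, 'UNCERTAIN']   # Swin/SDXL
result_4output = [4, 0.20, 0.80, 'AI']          # Swin/SDXL-FLUX
result_5output = [5, 0.0, 0.0, 'MAINTENANCE']   # GOAT (external inference, currently stubbed)

# predict_image() returns the processed image plus all five rows:
combined_outputs = [result_1output, result_2output, result_3output,
                    result_4output, result_5output]
print(combined_outputs)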
.gitignore ADDED

@@ -0,0 +1 @@
+.env
app.py CHANGED
@@ -67,7 +67,8 @@ def predict_image(img, confidence_threshold):
     try:
         prediction_1 = clf_1(img_pil)
         result_1 = {pred['label']: pred['score'] for pred in prediction_1}
-
+        result_1output = [1, result_1['real'], result_1['artificial']]
+        print(result_1output)
         # Ensure the result dictionary contains all class names
         for class_name in class_names_1:
             if class_name not in result_1:
@@ -75,18 +76,23 @@ def predict_image(img, confidence_threshold):
         # Check if either class meets the confidence threshold
         if result_1['artificial'] >= confidence_threshold:
             label_1 = f"AI, Confidence: {result_1['artificial']:.4f}"
+            result_1output += ['AI']
         elif result_1['real'] >= confidence_threshold:
             label_1 = f"Real, Confidence: {result_1['real']:.4f}"
+            result_1output += ['REAL']
         else:
             label_1 = "Uncertain Classification"
+            result_1output += ['UNCERTAIN']
+
     except Exception as e:
         label_1 = f"Error: {str(e)}"
-
+    print(result_1output)
     # Predict using the second model
     try:
         prediction_2 = clf_2(img_pil)
         result_2 = {pred['label']: pred['score'] for pred in prediction_2}
-
+        result_2output = [2, result_2['Real Image'], result_2['AI Image']]
+        print(result_2output)
         # Ensure the result dictionary contains all class names
         for class_name in class_names_2:
             if class_name not in result_2:
@@ -94,10 +100,13 @@ def predict_image(img, confidence_threshold):
         # Check if either class meets the confidence threshold
         if result_2['AI Image'] >= confidence_threshold:
             label_2 = f"AI, Confidence: {result_2['AI Image']:.4f}"
+            result_2output += ['AI']
         elif result_2['Real Image'] >= confidence_threshold:
             label_2 = f"Real, Confidence: {result_2['Real Image']:.4f}"
+            result_2output += ['REAL']
         else:
             label_2 = "Uncertain Classification"
+            result_2output += ['UNCERTAIN']
     except Exception as e:
         label_2 = f"Error: {str(e)}"

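The same threshold check is repeated for each classifier above. As a standalone illustration, the pattern boils down to the sketch below; classify_scores is a hypothetical helper for this sketch, not a function in app.py:

# Hypothetical helper mirroring the per-model threshold checks in predict_image();
# returns the display label and the tag appended to the model's output row.
def classify_scores(real_score, ai_score, confidence_threshold):
    if ai_score >= confidence_threshold:
        return f"AI, Confidence: {ai_score:.4f}", 'AI'
    elif real_score >= confidence_threshold:
        return f"Real, Confidence: {real_score:.4f}", 'REAL'
    return "Uncertain Classification", 'UNCERTAIN'

label, tag = classify_scores(0.93, 0.07, confidence_threshold=0.5)
print(label, tag)  # Real, Confidence: 0.9300 REAL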
@@ -109,10 +118,11 @@ def predict_image(img, confidence_threshold):
         logits_3 = outputs_3.logits
         probabilities_3 = softmax(logits_3.cpu().numpy()[0])
         result_3 = {
-            labels_3[
-            labels_3[
+            labels_3[1]: float(probabilities_3[1]),  # Real
+            labels_3[0]: float(probabilities_3[0])   # AI
         }
-
+        result_3output = [3, float(probabilities_3[1]), float(probabilities_3[0])]
+        print(result_3output)
         # Ensure the result dictionary contains all class names
         for class_name in labels_3:
             if class_name not in result_3:
@@ -120,10 +130,13 @@ def predict_image(img, confidence_threshold):
         # Check if either class meets the confidence threshold
         if result_3['AI'] >= confidence_threshold:
             label_3 = f"AI, Confidence: {result_3['AI']:.4f}"
+            result_3output += ['AI']
         elif result_3['Real'] >= confidence_threshold:
             label_3 = f"Real, Confidence: {result_3['Real']:.4f}"
+            result_3output += ['REAL']
         else:
             label_3 = "Uncertain Classification"
+            result_3output += ['UNCERTAIN']
     except Exception as e:
         label_3 = f"Error: {str(e)}"

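Models 3 and 4 convert raw logits to probabilities before the threshold check. A self-contained sketch of that step; the softmax helper, the labels_3 ordering, and the logit values are illustrative, while app.py applies the same pattern to the real model outputs:

import numpy as np

def softmax(x):
    # Numerically stable softmax over a 1-D array of logits.
    e = np.exp(x - np.max(x))
    return e / e.sum()

labels_3 = {0: 'AI', 1: 'Real'}        # assumed index-to-label mapping for this sketch
logits = np.array([1.2, -0.4])         # placeholder logits, not real model output
probabilities_3 = softmax(logits)

result_3 = {
    labels_3[1]: float(probabilities_3[1]),  # Real
    labels_3[0]: float(probabilities_3[0]),  # AI
}
result_3output = [3, float(probabilities_3[1]), float(probabilities_3[0])]
print(result_3, result_3output)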
@@ -135,9 +148,10 @@ def predict_image(img, confidence_threshold):
         logits_4 = outputs_4.logits
         probabilities_4 = softmax(logits_4.cpu().numpy()[0])
         result_4 = {
-            labels_4[
-            labels_4[
+            labels_4[1]: float(probabilities_4[1]),  # Real
+            labels_4[0]: float(probabilities_4[0])   # AI
         }
+        result_4output = [4, float(probabilities_4[1]), float(probabilities_4[0])]
         print(result_4)
         # Ensure the result dictionary contains all class names
         for class_name in labels_4:
@@ -146,19 +160,27 @@ def predict_image(img, confidence_threshold):
         # Check if either class meets the confidence threshold
         if result_4['AI'] >= confidence_threshold:
             label_4 = f"AI, Confidence: {result_4['AI']:.4f}"
+            result_4output += ['AI']
         elif result_4['Real'] >= confidence_threshold:
             label_4 = f"Real, Confidence: {result_4['Real']:.4f}"
+            result_4output += ['REAL']
         else:
             label_4 = "Uncertain Classification"
+            result_4output += ['UNCERTAIN']
     except Exception as e:
         label_4 = f"Error: {str(e)}"

     try:
+        result_5output = [5, 0.0, 0.0, 'MAINTENANCE']
         img_bytes = convert_pil_to_bytes(img_pil)
-
-
+        # print(img)
+        # print(img_bytes)
+        response5_raw = call_inference(img)
+        print(response5_raw)
+        response5 = response5_raw
         print(response5)
         label_5 = f"Result: {response5}"
+
     except Exception as e:
         label_5 = f"Error: {str(e)}"

@@ -170,32 +192,34 @@ def predict_image(img, confidence_threshold):
         "Swin/SDXL-FLUX": label_4,
         "GOAT": label_5
     }
-
+    combined_outputs = [ result_1output, result_2output, result_3output, result_4output, result_5output ]
+    return img_pil, combined_outputs

 # Define a function to generate the HTML content
 def generate_results_html(results):
+    print(results)
     html_content = f"""
     <link href="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css" rel="stylesheet">
     <div class="container">
         <div class="row mt-4">
             <div class="col">
-                <h5>SwinV2/detect</h5>
+                <h5>SwinV2/detect <span class="badge badge-secondary">M1</span></h5>
                 <p>{results.get("SwinV2/detect", "N/A")}</p>
             </div>
             <div class="col">
-                <h5>ViT/AI-vs-Real</h5>
+                <h5>ViT/AI-vs-Real <span class="badge badge-secondary">M2</span></h5>
                 <p>{results.get("ViT/AI-vs-Real", "N/A")}</p>
             </div>
             <div class="col">
-                <h5>Swin/SDXL</h5>
+                <h5>Swin/SDXL <span class="badge badge-secondary">M3</span></h5>
                 <p>{results.get("Swin/SDXL", "N/A")}</p>
             </div>
             <div class="col">
-                <h5>Swin/SDXL-FLUX</h5>
+                <h5>Swin/SDXL-FLUX <span class="badge badge-secondary">M4</span></h5>
                 <p>{results.get("Swin/SDXL-FLUX", "N/A")}</p>
             </div>
             <div class="col">
-                <h5>GOAT</h5>
+                <h5>GOAT <span class="badge badge-secondary">M5</span></h5>
                 <p>{results.get("GOAT", "N/A")}</p>
             </div>
         </div>
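generate_results_html() keys the results dict on the model names shown above. A small usage sketch with placeholder labels; one rendered column is reproduced to show where the new M1-M5 badges land:

# Placeholder labels; in app.py these strings come from predict_image().
results = {
    "SwinV2/detect":  "Real, Confidence: 0.9734",
    "ViT/AI-vs-Real": "AI, Confidence: 0.8812",
    "Swin/SDXL":      "Uncertain Classification",
    "Swin/SDXL-FLUX": "AI, Confidence: 0.9102",
    "GOAT":           "Result: <raw inference response>",
}

# One Bootstrap column, as generate_results_html() now renders it:
column = (
    '<div class="col">'
    '<h5>SwinV2/detect <span class="badge badge-secondary">M1</span></h5>'
    f'<p>{results.get("SwinV2/detect", "N/A")}</p>'
    '</div>'
)
print(column)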
@@ -214,11 +238,11 @@ with gr.Blocks() as iface:
     gr.Markdown("# AI Generated Image Classification")

     with gr.Row():
-        with gr.Column():
+        with gr.Column(scale=2):
             image_input = gr.Image(label="Upload Image to Analyze", sources=['upload'], type='pil')
             confidence_slider = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Confidence Threshold")
             inputs = [image_input, confidence_slider]
-        with gr.Column():
+        with gr.Column(scale=3):
             image_output = gr.Image(label="Processed Image")
             # Custom HTML component to display results in 5 columns
             results_html = gr.HTML(label="Model Predictions")
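The layout hunk above widens the results side relative to the inputs (scale=2 vs scale=3). A minimal, self-contained Gradio sketch of that arrangement; the predict stub and the Classify button are placeholders standing in for app.py's predict_image / generate_results_html wiring:

import gradio as gr

def predict(img, confidence_threshold):
    # Placeholder for app.py's predict_image() + generate_results_html() pipeline.
    return img, f"<p>Confidence threshold: {confidence_threshold}</p>"

with gr.Blocks() as iface:
    gr.Markdown("# AI Generated Image Classification")
    with gr.Row():
        with gr.Column(scale=2):   # narrower input column
            image_input = gr.Image(label="Upload Image to Analyze", sources=['upload'], type='pil')
            confidence_slider = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Confidence Threshold")
        with gr.Column(scale=3):   # wider results column
            image_output = gr.Image(label="Processed Image")
            results_html = gr.HTML(label="Model Predictions")
    classify_button = gr.Button("Classify")
    classify_button.click(predict,
                          inputs=[image_input, confidence_slider],
                          outputs=[image_output, results_html])

# iface.launch()  # uncomment to run locally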