Update app.py
app.py CHANGED
@@ -7,6 +7,10 @@ from PIL import Image
 import numpy as np
 from utils.goat import call_inference
 import io
+import warnings
+
+# Suppress warnings
+warnings.filterwarnings("ignore", category=UserWarning, message="Using a slow image processor as `use_fast` is unset")
 
 # Ensure using GPU if available
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
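The filter added above is narrow by design: `warnings.filterwarnings` treats `message` as a regular expression matched against the start of the warning text, so only this specific transformers notice is silenced and other `UserWarning`s still surface. (Presumably loading the image processor with `use_fast=True` would avoid the warning at its source; that is outside this diff.) A minimal, self-contained sketch of the behaviour, not part of the commit:

import warnings

# Same filter as in app.py: `message` is matched as a regex against the start
# of the warning text, so only this one notice is ignored.
warnings.filterwarnings(
    "ignore",
    category=UserWarning,
    message="Using a slow image processor as `use_fast` is unset",
)

with warnings.catch_warnings(record=True) as caught:
    # catch_warnings keeps the filters installed above, so the first warning
    # is dropped while the unrelated one is still recorded.
    warnings.warn("Using a slow image processor as `use_fast` is unset", UserWarning)
    warnings.warn("an unrelated warning", UserWarning)

assert [str(w.message) for w in caught] == ["an unrelated warning"]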
@@ -168,6 +172,43 @@ def predict_image(img, confidence_threshold):
     }
     return img_pil, combined_results
 
+# Define a function to generate the HTML content
+def generate_results_html(results):
+    html_content = f"""
+    <link href="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css" rel="stylesheet">
+    <div class="container">
+        <div class="row mt-4">
+            <div class="col">
+                <h5>SwinV2/detect</h5>
+                <p>{results.get("SwinV2/detect", "N/A")}</p>
+            </div>
+            <div class="col">
+                <h5>ViT/AI-vs-Real</h5>
+                <p>{results.get("ViT/AI-vs-Real", "N/A")}</p>
+            </div>
+            <div class="col">
+                <h5>Swin/SDXL</h5>
+                <p>{results.get("Swin/SDXL", "N/A")}</p>
+            </div>
+            <div class="col">
+                <h5>Swin/SDXL-FLUX</h5>
+                <p>{results.get("Swin/SDXL-FLUX", "N/A")}</p>
+            </div>
+            <div class="col">
+                <h5>GOAT</h5>
+                <p>{results.get("GOAT", "N/A")}</p>
+            </div>
+        </div>
+    </div>
+    """
+    return html_content
+
+# Modify the predict_image function to return the HTML content
+def predict_image_with_html(img, confidence_threshold):
+    img_pil, results = predict_image(img, confidence_threshold)
+    html_content = generate_results_html(results)
+    return img_pil, html_content
+
 # Define the Gradio interface
 with gr.Blocks() as iface:
     gr.Markdown("# AI Generated Image Classification")
@@ -185,42 +226,5 @@ with gr.Blocks() as iface:
 
     gr.Button("Predict").click(fn=predict_image_with_html, inputs=inputs, outputs=outputs)
 
-# Define a function to generate the HTML content
-def generate_results_html(results):
-    html_content = f"""
-    <link href="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css" rel="stylesheet">
-    <div class="container">
-        <div class="row mt-4">
-            <div class="col">
-                <h5>SwinV2/detect</h5>
-                <p>{results.get("SwinV2/detect", "N/A")}</p>
-            </div>
-            <div class="col">
-                <h5>ViT/AI-vs-Real</h5>
-                <p>{results.get("ViT/AI-vs-Real", "N/A")}</p>
-            </div>
-            <div class="col">
-                <h5>Swin/SDXL</h5>
-                <p>{results.get("Swin/SDXL", "N/A")}</p>
-            </div>
-            <div class="col">
-                <h5>Swin/SDXL-FLUX</h5>
-                <p>{results.get("Swin/SDXL-FLUX", "N/A")}</p>
-            </div>
-            <div class="col">
-                <h5>GOAT</h5>
-                <p>{results.get("GOAT", "N/A")}</p>
-            </div>
-        </div>
-    </div>
-    """
-    return html_content
-
-# Modify the predict_image function to return the HTML content
-def predict_image_with_html(img, confidence_threshold):
-    img_pil, results = predict_image(img, confidence_threshold)
-    html_content = generate_results_html(results)
-    return img_pil, html_content
-
 # Launch the interface
 iface.launch()
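The structure of the `results` dict that `generate_results_html` receives is produced by `predict_image`, which this diff does not show. Assuming it maps the model names used in the template to display strings, a call could look like the hypothetical example below (with `generate_results_html` from the hunk above in scope); any model missing from the dict renders as "N/A" via `dict.get`.

# Hypothetical input: the real combined_results comes from predict_image,
# which this diff does not show.
sample_results = {
    "SwinV2/detect": "AI-generated (0.97)",
    "ViT/AI-vs-Real": "Real (0.62)",
    "GOAT": "AI-generated (0.91)",
    # "Swin/SDXL" and "Swin/SDXL-FLUX" deliberately omitted
}

html = generate_results_html(sample_results)
assert "AI-generated (0.97)" in html
assert html.count("N/A") == 2  # the two omitted models fall back to "N/A"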
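For context, the `inputs` and `outputs` passed to the Predict button live in unchanged lines that this diff does not display. The sketch below shows one plausible way `predict_image_with_html` is wired into the interface; the component names, the slider range, and the stub prediction function are assumptions for illustration, not code from app.py. It also shows why moving the helper definitions above the `gr.Blocks` section matters: the name `predict_image_with_html` is looked up when `.click()` executes at interface build time, so it must already be defined.

import gradio as gr

# Stand-in for the real predict_image_with_html defined in app.py.
def predict_image_with_html(img, confidence_threshold):
    html = f"<p>confidence threshold = {confidence_threshold}</p>"
    return img, html

with gr.Blocks() as iface:
    gr.Markdown("# AI Generated Image Classification")
    with gr.Row():
        image_in = gr.Image(type="pil", label="Input image")                       # assumed component
        confidence = gr.Slider(0.0, 1.0, value=0.5, label="Confidence threshold")  # assumed range
    image_out = gr.Image(label="Annotated image")  # assumed component
    results_html = gr.HTML()                       # renders the generate_results_html output

    inputs = [image_in, confidence]
    outputs = [image_out, results_html]

    # predict_image_with_html must already exist here; defining it after this
    # block, as the old version did, raises a NameError at startup.
    gr.Button("Predict").click(fn=predict_image_with_html, inputs=inputs, outputs=outputs)

# Launch the interface
iface.launch()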