Update app.py

app.py CHANGED
@@ -6,6 +6,7 @@ Swin/CAFormer/DINOv2 AI detection
 • Swin-V7 / V8 / V9 : 4-class (photo / anime × AI / Non-AI)
 • CAFormer-V10 : 4-class (photo / anime × AI / Non-AI)
 • DINOv2-4class : 4-class (photo / anime × AI / Non-AI)
+• DINOv2-MeanPool-Contrastive : 4-class (photo / anime × AI / Non-AI)
 -------------------------------------------------------------------
 Author: telecomadm1145
 """
@@ -35,7 +36,9 @@ HF_FILENAMES = {
     "V1-CAFormer": "caformer_b36_4class.safetensors",
     "V2-CAFormer": "caformer_b36_4class_95.safetensors",
     "V2.5-CAFormer": "caformer_b36_4class_96.safetensors",
-    "DINOv2-4class": "dinov2_4class.safetensors",
+    "DINOv2-4class": "dinov2_4class.safetensors",
+    # Added new DINOv2 checkpoint filename
+    "DINOv2-MeanPool-Contrastive": "dinov2-base-4class-contrastive_epoch4.safetensors",
 }
 
 CKPT_META = {
@@ -57,13 +60,20 @@ CKPT_META = {
                        "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
     "V2.5-CAFormer": { "n_cls": 4, "head": "v7", "backbone": "caformer_b36.sail_in22k_ft_in1k_384",
                        "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
-    #
+    # Updated original DINOv2 metadata with a specific model_type
     "DINOv2-4class": {
-        "model_type": "dinov2",
+        "model_type": "dinov2_weighted_pool",
         "backbone": 'facebook/dinov2-base',
        "n_cls": 4,
         "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]
     },
+    # Added new DINOv2 model metadata
+    "DINOv2-MeanPool-Contrastive": {
+        "model_type": "dinov2_mean_pool",
+        "backbone": 'facebook/dinov2-base',
+        "n_cls": 4,
+        "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]
+    }
 }
 
 DEFAULT_CKPT = "V1-CAFormer"
@@ -79,8 +89,8 @@ print(f"Using device: {device}")
 model, current_ckpt = None, None
 current_meta = None
 
-# ---
-class DINOv2Classifier(nn.Module):
+# --- Original DINOv2 Classifier (Weighted Attention Pooling) ---
+class DINOv2Classifier_WeightedPool(nn.Module):
     def __init__(self, model_name, num_classes):
         super().__init__()
         self.backbone = AutoModel.from_pretrained(model_name)
@@ -125,10 +135,36 @@ class DINOv2Classifier(nn.Module):
         pooling_weights = torch.softmax(raw_weights, dim=-1)
         pooled_output = torch.sum(outputs.last_hidden_state * pooling_weights.unsqueeze(-1), dim=1)
         return self.classifier(pooled_output)
-
+
+# --- New DINOv2 Classifier (Mean Pooling) ---
+class DINOv2Classifier_MeanPool(nn.Module):
+    def __init__(self, model_name, num_classes):
+        super().__init__()
+        self.backbone = AutoModel.from_pretrained(model_name)
+        self.classifier = nn.Sequential(
+            nn.Dropout(DROPOUT_RATE),
+            nn.Linear(self.backbone.config.hidden_size, self.backbone.config.hidden_size),
+            nn.LayerNorm(self.backbone.config.hidden_size),
+            nn.GELU(),
+            nn.Dropout(DROPOUT_RATE),
+            nn.Linear(self.backbone.config.hidden_size, num_classes)
+        )
+
+        for module in self.classifier:
+            if isinstance(module, nn.Linear):
+                nn.init.xavier_uniform_(module.weight)
+                nn.init.constant_(module.bias, 0)
+
+    def forward(self, x, return_features=False):
+        outputs = self.backbone(x)
+        pooled_output = outputs.last_hidden_state.mean(dim=1)
+
+        if return_features:
+            return pooled_output
+
+        return self.classifier(pooled_output)
 
 
-# Renamed to ImageClassifier for clarity, but keeping original name to avoid breaking changes if subclassed elsewhere.
 class SwinClassifier(nn.Module):
     def __init__(self, model_name, num_classes, pretrained=True,
                  head_version="v4"):
@@ -195,8 +231,6 @@ def load_model(ckpt_name: str):
     meta = CKPT_META[ckpt_name]
     ckpt_filename = HF_FILENAMES[ckpt_name]
 
-    # Check if the checkpoint is DINOv2 and handle its local path
-    # Download other models from HF Hub
     ckpt_file = hf_hub_download(
         repo_id=REPO_ID,
         filename=ckpt_filename,
@@ -205,8 +239,14 @@
     print(f"Checkpoint: {ckpt_file}")
 
     # Build model structure based on model_type
-    if meta.get("model_type") == "dinov2":
-        model = DINOv2Classifier(
+    model_type = meta.get("model_type")
+    if model_type == "dinov2_weighted_pool":
+        model = DINOv2Classifier_WeightedPool(
+            model_name=meta["backbone"],
+            num_classes=meta["n_cls"]
+        ).to(device)
+    elif model_type == "dinov2_mean_pool":
+        model = DINOv2Classifier_MeanPool(
             model_name=meta["backbone"],
             num_classes=meta["n_cls"]
         ).to(device)
@@ -253,8 +293,8 @@ def predict(image: Image.Image,
     load_model(ckpt_name)
 
     # Select transform based on the current model type
-    if current_meta.get("model_type") == "dinov2":
-        # DINOv2 specific transform
+    if "dinov2" in current_meta.get("model_type", ""):
+        # DINOv2 specific transform
         tfm = transforms.Compose([
             transforms.Resize((224, 224)),
             transforms.ToTensor(),
@@ -283,7 +323,7 @@ def launch():
         gr.Markdown("# AI Detector")
         gr.Markdown(
             "Choose a model checkpoint on the left, upload an image, "
-            "and click **Run** to see predictions. Checkpoint V7+ and DINOv2 output 4 classes."
+            "and click **Run** to see predictions. Checkpoint V7+ and all DINOv2 models output 4 classes."
        )
 
         with gr.Row():
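A note for review: the only functional difference between the two DINOv2 heads is the pooling step. Below is a toy comparison; the token shape is assumed from facebook/dinov2-base at 224x224 input (256 patch tokens plus one CLS token, hidden size 768), and the Linear scorer is a stand-in, since this diff does not show how raw_weights is computed in the weighted-pool class.

import torch
import torch.nn as nn

tokens = torch.randn(1, 257, 768)  # [batch, tokens, hidden], as produced by dinov2-base

# Mean pooling (DINOv2Classifier_MeanPool): every token weighted equally.
mean_pooled = tokens.mean(dim=1)                                     # [1, 768]

# Weighted attention pooling (DINOv2Classifier_WeightedPool, sketched):
# a learned per-token score, softmax-normalized into pooling weights.
scorer = nn.Linear(768, 1)                                           # stand-in for the real scorer
pooling_weights = torch.softmax(scorer(tokens).squeeze(-1), dim=-1)  # [1, 257]
weighted_pooled = torch.sum(tokens * pooling_weights.unsqueeze(-1), dim=1)  # [1, 768]

The return_features=True branch of the new class returns the pooled embedding before the classifier head; judging by the checkpoint name (...contrastive_epoch4), that is presumably the feature the contrastive objective was trained on.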
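And a minimal sketch of the load-and-predict path this diff wires up, for trying the new checkpoint outside the Space. Assumptions: REPO_ID and DROPOUT_RATE are defined elsewhere in app.py and are placeholders here, the normalization statistics are the usual ImageNet values implied by the truncated transform, and DINOv2Classifier_MeanPool is the class added above (together with its AutoModel import).

import torch
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

DROPOUT_RATE = 0.1                        # placeholder; the real value is set elsewhere in app.py
REPO_ID = "user/ai-detector-checkpoints"  # placeholder; the real repo id is set elsewhere in app.py

# Fetch the new mean-pool checkpoint, as load_model() does.
ckpt_file = hf_hub_download(
    repo_id=REPO_ID,
    filename="dinov2-base-4class-contrastive_epoch4.safetensors",
)

# Build the model from its CKPT_META entry and load the weights.
model = DINOv2Classifier_MeanPool(model_name="facebook/dinov2-base", num_classes=4)
model.load_state_dict(load_file(ckpt_file))
model.eval()

# The DINOv2 branch of predict(): 224x224 resize, assumed ImageNet normalization.
tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

x = tfm(Image.open("sample.jpg").convert("RGB")).unsqueeze(0)
with torch.no_grad():
    probs = torch.softmax(model(x), dim=-1)[0]
for name, p in zip(["non_ai", "ai", "ani_non_ai", "ani_ai"], probs.tolist()):
    print(f"{name}: {p:.3f}")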