telecomadm1145 committed on
Commit
2ffd4a9
·
verified ·
1 Parent(s): edb115a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -14
app.py CHANGED
@@ -6,6 +6,7 @@ Swin/CAFormer/DINOv2 AI detection
6
  • Swin-V7 / V8 / V9 : 4-class (photo / anime × AI / Non-AI)
7
  • CAFormer-V10 : 4-class (photo / anime × AI / Non-AI)
8
  • DINOv2-4class : 4-class (photo / anime × AI / Non-AI)
 
9
  -------------------------------------------------------------------
10
  Author: telecomadm1145
11
  """
@@ -35,7 +36,9 @@ HF_FILENAMES = {
35
  "V1-CAFormer": "caformer_b36_4class.safetensors",
36
  "V2-CAFormer": "caformer_b36_4class_95.safetensors",
37
  "V2.5-CAFormer": "caformer_b36_4class_96.safetensors",
38
- "DINOv2-4class": "dinov2_4class.safetensors", # Added DINOv2 checkpoint
 
 
39
  }
40
 
41
  CKPT_META = {
@@ -57,13 +60,20 @@ CKPT_META = {
57
  "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
58
  "V2.5-CAFormer": { "n_cls": 4, "head": "v7", "backbone": "caformer_b36.sail_in22k_ft_in1k_384",
59
  "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
60
- # Added DINOv2 metadata
61
  "DINOv2-4class": {
62
- "model_type": "dinov2",
63
  "backbone": 'facebook/dinov2-base',
64
  "n_cls": 4,
65
  "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]
66
  },
 
 
 
 
 
 
 
67
  }
68
 
69
  DEFAULT_CKPT = "V1-CAFormer"
@@ -79,8 +89,8 @@ print(f"Using device: {device}")
79
  model, current_ckpt = None, None
80
  current_meta = None
81
 
82
- # --- Start of code from train.py ---
83
- class DINOv2Classifier(nn.Module):
84
  def __init__(self, model_name, num_classes):
85
  super().__init__()
86
  self.backbone = AutoModel.from_pretrained(model_name)
@@ -125,10 +135,36 @@ class DINOv2Classifier(nn.Module):
125
  pooling_weights = torch.softmax(raw_weights, dim=-1)
126
  pooled_output = torch.sum(outputs.last_hidden_state * pooling_weights.unsqueeze(-1), dim=1)
127
  return self.classifier(pooled_output)
128
- # --- End of code from train.py ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
 
131
- # Renamed to ImageClassifier for clarity, but keeping original name to avoid breaking changes if subclassed elsewhere.
132
  class SwinClassifier(nn.Module):
133
  def __init__(self, model_name, num_classes, pretrained=True,
134
  head_version="v4"):
@@ -195,8 +231,6 @@ def load_model(ckpt_name: str):
195
  meta = CKPT_META[ckpt_name]
196
  ckpt_filename = HF_FILENAMES[ckpt_name]
197
 
198
- # Check if the checkpoint is DINOv2 and handle its local path
199
- # Download other models from HF Hub
200
  ckpt_file = hf_hub_download(
201
  repo_id=REPO_ID,
202
  filename=ckpt_filename,
@@ -205,8 +239,14 @@ def load_model(ckpt_name: str):
205
  print(f"Checkpoint: {ckpt_file}")
206
 
207
  # Build model structure based on model_type
208
- if meta.get("model_type") == "dinov2":
209
- model = DINOv2Classifier(
 
 
 
 
 
 
210
  model_name=meta["backbone"],
211
  num_classes=meta["n_cls"]
212
  ).to(device)
@@ -253,8 +293,8 @@ def predict(image: Image.Image,
253
  load_model(ckpt_name)
254
 
255
  # Select transform based on the current model type
256
- if current_meta.get("model_type") == "dinov2":
257
- # DINOv2 specific transform from train.py
258
  tfm = transforms.Compose([
259
  transforms.Resize((224, 224)),
260
  transforms.ToTensor(),
@@ -283,7 +323,7 @@ def launch():
283
  gr.Markdown("# AI Detector")
284
  gr.Markdown(
285
  "Choose a model checkpoint on the left, upload an image, "
286
- "and click **Run** to see predictions. Checkpoint V7+ and DINOv2 outputs 4 classes."
287
  )
288
 
289
  with gr.Row():
 
6
  • Swin-V7 / V8 / V9 : 4-class (photo / anime × AI / Non-AI)
7
  • CAFormer-V10 : 4-class (photo / anime × AI / Non-AI)
8
  • DINOv2-4class : 4-class (photo / anime × AI / Non-AI)
9
+ • DINOv2-MeanPool-Contrastive : 4-class (photo / anime × AI / Non-AI)
10
  -------------------------------------------------------------------
11
  Author: telecomadm1145
12
  """
 
36
  "V1-CAFormer": "caformer_b36_4class.safetensors",
37
  "V2-CAFormer": "caformer_b36_4class_95.safetensors",
38
  "V2.5-CAFormer": "caformer_b36_4class_96.safetensors",
39
+ "DINOv2-4class": "dinov2_4class.safetensors",
40
+ # Added new DINOv2 checkpoint filename
41
+ "DINOv2-MeanPool-Contrastive": "dinov2-base-4class-contrastive_epoch4.safetensors",
42
  }
43
 
44
  CKPT_META = {
 
60
  "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
61
  "V2.5-CAFormer": { "n_cls": 4, "head": "v7", "backbone": "caformer_b36.sail_in22k_ft_in1k_384",
62
  "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
63
+ # Updated original DINOv2 metadata with a specific model_type
64
  "DINOv2-4class": {
65
+ "model_type": "dinov2_weighted_pool",
66
  "backbone": 'facebook/dinov2-base',
67
  "n_cls": 4,
68
  "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]
69
  },
70
+ # Added new DINOv2 model metadata
71
+ "DINOv2-MeanPool-Contrastive": {
72
+ "model_type": "dinov2_mean_pool",
73
+ "backbone": 'facebook/dinov2-base',
74
+ "n_cls": 4,
75
+ "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]
76
+ }
77
  }
78
 
79
  DEFAULT_CKPT = "V1-CAFormer"
 
89
  model, current_ckpt = None, None
90
  current_meta = None
91
 
92
+ # --- Original DINOv2 Classifier (Weighted Attention Pooling) ---
93
+ class DINOv2Classifier_WeightedPool(nn.Module):
94
  def __init__(self, model_name, num_classes):
95
  super().__init__()
96
  self.backbone = AutoModel.from_pretrained(model_name)
 
135
  pooling_weights = torch.softmax(raw_weights, dim=-1)
136
  pooled_output = torch.sum(outputs.last_hidden_state * pooling_weights.unsqueeze(-1), dim=1)
137
  return self.classifier(pooled_output)
138
+
139
+ # --- New DINOv2 Classifier (Mean Pooling) ---
140
+ class DINOv2Classifier_MeanPool(nn.Module):
141
+ def __init__(self, model_name, num_classes):
142
+ super().__init__()
143
+ self.backbone = AutoModel.from_pretrained(model_name)
144
+ self.classifier = nn.Sequential(
145
+ nn.Dropout(DROPOUT_RATE),
146
+ nn.Linear(self.backbone.config.hidden_size, self.backbone.config.hidden_size),
147
+ nn.LayerNorm(self.backbone.config.hidden_size),
148
+ nn.GELU(),
149
+ nn.Dropout(DROPOUT_RATE),
150
+ nn.Linear(self.backbone.config.hidden_size, num_classes)
151
+ )
152
+
153
+ for module in self.classifier:
154
+ if isinstance(module, nn.Linear):
155
+ nn.init.xavier_uniform_(module.weight)
156
+ nn.init.constant_(module.bias, 0)
157
+
158
+ def forward(self, x, return_features=False):
159
+ outputs = self.backbone(x)
160
+ pooled_output = outputs.last_hidden_state.mean(dim=1)
161
+
162
+ if return_features:
163
+ return pooled_output
164
+
165
+ return self.classifier(pooled_output)
166
 
167
 
 
168
  class SwinClassifier(nn.Module):
169
  def __init__(self, model_name, num_classes, pretrained=True,
170
  head_version="v4"):
 
231
  meta = CKPT_META[ckpt_name]
232
  ckpt_filename = HF_FILENAMES[ckpt_name]
233
 
 
 
234
  ckpt_file = hf_hub_download(
235
  repo_id=REPO_ID,
236
  filename=ckpt_filename,
 
239
  print(f"Checkpoint: {ckpt_file}")
240
 
241
  # Build model structure based on model_type
242
+ model_type = meta.get("model_type")
243
+ if model_type == "dinov2_weighted_pool":
244
+ model = DINOv2Classifier_WeightedPool(
245
+ model_name=meta["backbone"],
246
+ num_classes=meta["n_cls"]
247
+ ).to(device)
248
+ elif model_type == "dinov2_mean_pool":
249
+ model = DINOv2Classifier_MeanPool(
250
  model_name=meta["backbone"],
251
  num_classes=meta["n_cls"]
252
  ).to(device)
 
293
  load_model(ckpt_name)
294
 
295
  # Select transform based on the current model type
296
+ if "dinov2" in current_meta.get("model_type", ""):
297
+ # DINOv2 specific transform
298
  tfm = transforms.Compose([
299
  transforms.Resize((224, 224)),
300
  transforms.ToTensor(),
 
323
  gr.Markdown("# AI Detector")
324
  gr.Markdown(
325
  "Choose a model checkpoint on the left, upload an image, "
326
+ "and click **Run** to see predictions. Checkpoint V7+ and all DINOv2 models output 4 classes."
327
  )
328
 
329
  with gr.Row():