# AIDetectV2 / app.py — telecomadm1145, commit 2ffd4a9 ("Update app.py", verified)
# -*- coding: utf-8 -*-
"""
Swin/CAFormer/DINOv2 AI detection
-------------------------------------------------------------------
• Swin-V2 / V4 : 2-class (AI vs. Non-AI)
• Swin-V7 / V8 / V9 : 4-class (photo / anime × AI / Non-AI)
• CAFormer-V10 (V1 / V2 / V2.5-CAFormer) : 4-class (photo / anime × AI / Non-AI)
• DINOv2-4class : 4-class (photo / anime × AI / Non-AI)
• DINOv2-MeanPool-Contrastive : 4-class (photo / anime × AI / Non-AI)
-------------------------------------------------------------------
Author: telecomadm1145
"""
import os, torch, timm, numpy as np
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file # Added for .safetensors support
# Added for DINOv2 model
from transformers import AutoModel
from torchvision import transforms
# --------------------------------------------------
# 1. Model & Checkpoint Meta-data
# --------------------------------------------------
REPO_ID = "telecomadm1145/swin-ai-detection"

# Checkpoint key -> file name inside REPO_ID. Insertion order is the order
# in which launch() lists the keys in the checkpoint dropdown.
HF_FILENAMES = {
    # Swin-L, 2-class heads
    "V2": "swin_classifier_stage1_v2_epoch_3.pth",
    "V4": "swin_classifier_stage1_v4.pth",
    # "V5(underfitting)" (swin_classifier_stage1_v5_fp16.pth, head "v5") is disabled.
    # Swin-L, 4-class heads
    "V7": "swin_classifier_4class_fp16_v7.pth",
    "V8": "swin_classifier_4class_fp16_v8_epoch7_acc9740.pth",
    "V9": "swin_classifier_4class_fp16_v9_acc9861.pth",
    # CAFormer-B36, 4-class heads
    "V1-CAFormer": "caformer_b36_4class.safetensors",
    "V2-CAFormer": "caformer_b36_4class_95.safetensors",
    "V2.5-CAFormer": "caformer_b36_4class_96.safetensors",
    # DINOv2-base classifiers
    "DINOv2-4class": "dinov2_4class.safetensors",
    "DINOv2-MeanPool-Contrastive": "dinov2-base-4class-contrastive_epoch4.safetensors",
}

# Label vocabularies; index i is the name of output logit i.
_BINARY_LABELS = ("Non-AI Generated", "AI Generated")
_QUAD_LABELS = ("non_ai", "ai", "ani_non_ai", "ani_ai")

_SWIN_L = "swin_large_patch4_window12_384"
_CAFORMER_B36 = "caformer_b36.sail_in22k_ft_in1k_384"


def _timm_meta(backbone, head, labels):
    """Return metadata for a timm-built (Swin / CAFormer) checkpoint.

    ``names`` is a fresh list per entry, so entries never share label lists.
    """
    return {"n_cls": len(labels), "head": head, "backbone": backbone,
            "names": list(labels)}


def _dinov2_meta(model_type):
    """Return metadata for a 4-class DINOv2-base checkpoint.

    ``model_type`` tells load_model which pooling classifier to build.
    """
    return {"model_type": model_type, "backbone": "facebook/dinov2-base",
            "n_cls": len(_QUAD_LABELS), "names": list(_QUAD_LABELS)}


# Checkpoint key -> architecture description consumed by load_model/predict.
CKPT_META = {
    "V2": _timm_meta(_SWIN_L, "v4", _BINARY_LABELS),
    "V4": _timm_meta(_SWIN_L, "v4", _BINARY_LABELS),
    "V7": _timm_meta(_SWIN_L, "v7", _QUAD_LABELS),
    "V8": _timm_meta(_SWIN_L, "v7", _QUAD_LABELS),
    "V9": _timm_meta(_SWIN_L, "v7", _QUAD_LABELS),
    "V1-CAFormer": _timm_meta(_CAFORMER_B36, "v7", _QUAD_LABELS),
    "V2-CAFormer": _timm_meta(_CAFORMER_B36, "v7", _QUAD_LABELS),
    "V2.5-CAFormer": _timm_meta(_CAFORMER_B36, "v7", _QUAD_LABELS),
    "DINOv2-4class": _dinov2_meta("dinov2_weighted_pool"),
    "DINOv2-MeanPool-Contrastive": _dinov2_meta("dinov2_mean_pool"),
}

DEFAULT_CKPT = "V1-CAFormer"
LOCAL_CKPT_DIR = "./checkpoints"
SEED = 4421
DROP_RATE = 0.1     # base dropout rate of the SwinClassifier heads
DROPOUT_RATE = 0.1  # dropout in the DINOv2 heads (value from train.py)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
torch.manual_seed(SEED)
np.random.seed(SEED)
print(f"Using device: {device}")

# State managed by load_model(): the active network, its checkpoint key and
# its CKPT_META entry. All three are None until the first load.
model = None
current_ckpt = None
current_meta = None
# --- Original DINOv2 Classifier (Weighted Attention Pooling) ---
class DINOv2Classifier_WeightedPool(nn.Module):
    """DINOv2 backbone + learned attention-weighted token pooling + MLP head.

    Args:
        model_name: Hugging Face model id passed to ``AutoModel.from_pretrained``.
        num_classes: number of output logits.
    """

    def __init__(self, model_name, num_classes):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_name)
        hidden = self.backbone.config.hidden_size
        # Self-attention over the backbone tokens; its output feeds weight_mlp.
        self.weight_self_attention = nn.MultiheadAttention(
            embed_dim=hidden,
            num_heads=self.backbone.config.num_attention_heads,
            dropout=self.backbone.config.hidden_dropout_prob,
            batch_first=True,
        )
        # Maps every attended token to one scalar pooling score.
        self.weight_mlp = nn.Sequential(
            nn.Linear(hidden, hidden * 4),
            nn.LayerNorm(hidden * 4),
            nn.GELU(),
            nn.Linear(hidden * 4, 1),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(DROPOUT_RATE),
            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
            nn.Dropout(DROPOUT_RATE),
            nn.Linear(hidden, num_classes),
        )
        nn.init.xavier_uniform_(self.weight_self_attention.in_proj_weight)
        nn.init.xavier_uniform_(self.weight_self_attention.out_proj.weight)
        nn.init.constant_(self.weight_self_attention.out_proj.bias, 0)
        # Recurse into the Sequential containers: testing the containers
        # themselves against nn.Linear never matches, so nothing was initialised.
        self._reset_linear_layers(self.weight_mlp, self.classifier)

    @staticmethod
    def _reset_linear_layers(*containers):
        """Xavier-uniform init the weights and zero the biases of every
        ``nn.Linear`` found (recursively) inside *containers*."""
        for container in containers:
            for module in container.modules():
                if isinstance(module, nn.Linear):
                    nn.init.xavier_uniform_(module.weight)
                    if module.bias is not None:
                        nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Return class logits for a batch of preprocessed images *x*."""
        # Tokens are treated as (batch, seq, hidden): attention is batch_first
        # and pooling sums over dim=1.
        tokens = self.backbone(x).last_hidden_state
        attended, _ = self.weight_self_attention(tokens, tokens, tokens)
        scores = self.weight_mlp(attended).squeeze(-1)
        # Softmax over the token axis -> per-image convex pooling weights.
        weights = torch.softmax(scores, dim=-1)
        pooled = torch.sum(tokens * weights.unsqueeze(-1), dim=1)
        return self.classifier(pooled)
# --- New DINOv2 Classifier (Mean Pooling) ---
class DINOv2Classifier_MeanPool(nn.Module):
    """DINOv2 backbone + mean over all output tokens + MLP head.

    Args:
        model_name: Hugging Face model id passed to ``AutoModel.from_pretrained``.
        num_classes: number of output logits.
    """

    def __init__(self, model_name, num_classes):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_name)
        width = self.backbone.config.hidden_size
        head_layers = [
            nn.Dropout(DROPOUT_RATE),
            nn.Linear(width, width),
            nn.LayerNorm(width),
            nn.GELU(),
            nn.Dropout(DROPOUT_RATE),
            nn.Linear(width, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)
        # Xavier weights / zero biases on the two Linear layers of the head.
        linear_layers = (layer for layer in self.classifier if isinstance(layer, nn.Linear))
        for linear in linear_layers:
            nn.init.xavier_uniform_(linear.weight)
            nn.init.constant_(linear.bias, 0)

    def forward(self, x, return_features=False):
        """Return logits, or the pooled features when *return_features* is true."""
        # Average over dim=1 of last_hidden_state (the token axis).
        features = self.backbone(x).last_hidden_state.mean(dim=1)
        return features if return_features else self.classifier(features)
class SwinClassifier(nn.Module):
    """timm image backbone (Swin or CAFormer) followed by a versioned MLP head.

    Args:
        model_name: timm model name. It is created with ``num_classes=0`` and
            its output feeds the head, whose input width is
            ``backbone.num_features``.
        num_classes: number of output logits.
        pretrained: forwarded to ``timm.create_model``.
        head_version: ``"v7"`` selects the compact 64-unit GELU head (used by
            V7/V8/V9 and the CAFormer checkpoints), ``"v5"`` the 512-128 GELU
            head, and any other value the 512-128 ReLU head of V2/V4.
    """

    def __init__(self, model_name, num_classes, pretrained=True,
                 head_version="v4"):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        # Preprocessing settings for this backbone; build_transform reads them.
        self.data_config = timm.data.resolve_data_config({}, model=self.backbone)
        in_features = self.backbone.num_features

        if head_version == "v7":
            layers = [
                nn.Dropout(DROP_RATE),
                nn.Linear(in_features, 64),
                nn.BatchNorm1d(64),
                nn.GELU(),
                nn.Dropout(DROP_RATE * 0.8),
                nn.Linear(64, num_classes),
            ]
        else:
            # v5 and the default v2/v4 head share one 512-128 layout and
            # differ only in the activation function.
            activation = nn.GELU if head_version == "v5" else nn.ReLU
            layers = [
                nn.Dropout(DROP_RATE),
                nn.Linear(in_features, 512),
                nn.BatchNorm1d(512),
                activation(),
                nn.Dropout(DROP_RATE * 0.7),
                nn.Linear(512, 128),
                nn.BatchNorm1d(128),
                activation(),
                nn.Dropout(DROP_RATE * 0.5),
                nn.Linear(128, num_classes),
            ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits for a batch of preprocessed images *x*."""
        features = self.backbone(x)
        return self.classifier(features)
# --------------------------------------------------
# 4. Dynamic model loading
# --------------------------------------------------
def load_model(ckpt_name: str):
    """Make *ckpt_name* the active model (no-op if it already is).

    Downloads the checkpoint file from ``REPO_ID`` into ``LOCAL_CKPT_DIR``,
    builds the architecture described by ``CKPT_META[ckpt_name]``, loads the
    weights with ``strict=True`` and switches the network to eval mode.

    The globals ``model`` / ``current_ckpt`` / ``current_meta`` are published
    together only after the weights have loaded. If loading fails, they stay
    ``None`` and the next call retries. They can never describe a
    half-initialised network under a stale checkpoint name.

    Raises:
        KeyError: if *ckpt_name* is not a key of ``CKPT_META`` / ``HF_FILENAMES``.
    """
    global model, current_ckpt, current_meta
    if ckpt_name == current_ckpt and model is not None:
        return
    print(f"\n🔄 Switching to {ckpt_name} ...")
    meta = CKPT_META[ckpt_name]
    ckpt_filename = HF_FILENAMES[ckpt_name]
    ckpt_file = hf_hub_download(
        repo_id=REPO_ID,
        filename=ckpt_filename,
        local_dir=LOCAL_CKPT_DIR, force_download=False
    )
    print(f"Checkpoint: {ckpt_file}")

    # Release the previous network before building the next one so two large
    # models are never resident on the device at the same time.
    model, current_ckpt, current_meta = None, None, None
    if device == "cuda":
        torch.cuda.empty_cache()

    # Build the architecture selected by meta["model_type"].
    model_type = meta.get("model_type")
    if model_type == "dinov2_weighted_pool":
        net = DINOv2Classifier_WeightedPool(
            model_name=meta["backbone"],
            num_classes=meta["n_cls"]
        )
    elif model_type == "dinov2_mean_pool":
        net = DINOv2Classifier_MeanPool(
            model_name=meta["backbone"],
            num_classes=meta["n_cls"]
        )
    else:  # Swin / CAFormer (timm backbones)
        net = SwinClassifier(
            meta["backbone"],
            num_classes=meta["n_cls"],
            pretrained=False,
            head_version=meta.get("head", "v4")
        )
    net = net.to(device)

    # Compatible load for .pth and .safetensors.
    if ckpt_filename.endswith(".safetensors"):
        state = load_file(ckpt_file, device=device)
    else:
        # NOTE(security): weights_only=False unpickles arbitrary objects.
        # Acceptable only because the file comes from the fixed REPO_ID;
        # never point this at user-supplied files.
        state = torch.load(ckpt_file, map_location=device, weights_only=False)
    # Some .pth files wrap the weights under "model_state_dict".
    net.load_state_dict(state.get("model_state_dict", state), strict=True)
    net.eval()

    model, current_ckpt, current_meta = net, ckpt_name, meta
    print(f"✅ {ckpt_name} loaded (classes = {meta['n_cls']}).")
# --------------------------------------------------
# 5. Transform factory
# --------------------------------------------------
def build_transform(is_training: bool, interpolation: str):
    """Create the timm preprocessing pipeline for the loaded timm model.

    Uses the loaded model's ``data_config`` with only the resize
    ``interpolation`` overridden. Only ``SwinClassifier`` defines
    ``data_config``; the DINOv2 classes in this file do not.

    Raises:
        RuntimeError: if no model has been loaded yet.
    """
    if model is None:
        raise RuntimeError("Model not loaded yet.")
    transform_cfg = {**model.data_config, "interpolation": interpolation}
    return timm.data.create_transform(**transform_cfg, is_training=is_training)
# --------------------------------------------------
# 6. Inference
# --------------------------------------------------
# Fixed DINOv2 preprocessing (224x224 resize, standard ImageNet mean/std).
# Built once at import instead of on every request.
_DINOV2_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


@torch.no_grad()
def predict(image: Image.Image,
            ckpt_name: str,
            interpolation: str = "bicubic"):
    """Classify *image* with checkpoint *ckpt_name*.

    Args:
        image: input picture; ``None`` (nothing uploaded) returns ``None``.
        ckpt_name: key of ``CKPT_META``; the model is (re)loaded if needed.
        interpolation: resize interpolation for the timm (Swin/CAFormer)
            transform; ignored by the DINOv2 pipeline.

    Returns:
        ``{class_name: probability}`` (softmax over the logits) for every
        class of the checkpoint, or ``None`` when *image* is ``None``.
    """
    if image is None:
        return None
    load_model(ckpt_name)
    # Both pipelines normalise with 3-channel mean/std, which fails on
    # RGBA, palette or grayscale inputs. Convert anything non-RGB first.
    if image.mode != "RGB":
        image = image.convert("RGB")
    if "dinov2" in current_meta.get("model_type", ""):
        tfm = _DINOV2_TRANSFORM
    else:
        tfm = build_transform(False, interpolation)
    inp = tfm(image).unsqueeze(0).to(device)
    probs = F.softmax(model(inp), dim=1)[0].cpu()
    # One entry per class name, so gr.Label works for 2- and 4-class models.
    return {name: float(p) for name, p in zip(current_meta["names"], probs)}
# --------------------------------------------------
# 7. Gradio UI
# --------------------------------------------------
def launch():
    """Build the Gradio UI and start the server.

    Preloads ``DEFAULT_CKPT`` so the first request does not pay for the
    download/build. ``.png``/``.jpg``/``.jpeg`` files in ``./examples`` are
    offered as examples, sorted by path so their order is stable across runs.
    The folder is created if it does not exist.
    """
    load_model(DEFAULT_CKPT)  # preload
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# AI Detector")
        gr.Markdown(
            "Choose a model checkpoint on the left, upload an image, "
            "and click **Run** to see predictions. Checkpoint V7+, all CAFormer "
            "and all DINOv2 models output 4 classes."
        )
        with gr.Row():
            with gr.Column(scale=1):
                run_btn = gr.Button("🚀 Run", variant="primary")
                sel_ckpt = gr.Dropdown(
                    list(HF_FILENAMES.keys()),
                    value=DEFAULT_CKPT, label="Checkpoint"
                )
                sel_interp = gr.Radio(
                    ["bilinear", "bicubic", "nearest"],
                    value="bicubic", label="Resize Interpolation (for Swin/CAFormer)"
                )
                in_img = gr.Image(type="pil", label="Upload Image")
            with gr.Column(scale=1):
                # Room for 4 entries covers both 2-class and 4-class checkpoints.
                out_lbl = gr.Label(num_top_classes=4, label="Predictions")
        run_btn.click(
            predict,
            inputs=[in_img, sel_ckpt, sel_interp],
            outputs=[out_lbl]
        )
        # Optional example folder
        if not os.path.exists("examples"):
            os.makedirs("examples", exist_ok=True)
            print("Put some jpg/png files inside ./examples for demo examples")
        example_files = sorted(
            os.path.join("examples", f)
            for f in os.listdir("examples")
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        )
        if example_files:
            gr.Examples(
                examples=[[f, DEFAULT_CKPT, "bicubic"] for f in example_files],
                inputs=[in_img, sel_ckpt, sel_interp],
                outputs=[out_lbl],
                fn=predict,
                cache_examples=False,
            )
    demo.launch()
# --------------------------------------------------
if __name__ == "__main__":
    # Start the Gradio app only when executed as a script, not on import.
    launch()