# -*- coding: utf-8 -*-
"""
Swin/CAFormer/DINOv2 AI detection
-------------------------------------------------------------------
• Swin-V2 / V4      : 2-class (AI vs. Non-AI)
• Swin-V7 / V8 / V9 : 4-class (photo / anime × AI / Non-AI)
• CAFormer V1 / V2 / V2.5 : 4-class (photo / anime × AI / Non-AI)
• DINOv2-4class     : 4-class (photo / anime × AI / Non-AI)
• DINOv2-MeanPool-Contrastive : 4-class (photo / anime × AI / Non-AI)
-------------------------------------------------------------------
Author: telecomadm1145
"""
import os, torch, timm, numpy as np
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file  # Added for .safetensors support
# Added for the DINOv2 models
from transformers import AutoModel
from torchvision import transforms
# --------------------------------------------------
# 1. Model & Checkpoint Meta-data
# --------------------------------------------------
REPO_ID = "telecomadm1145/swin-ai-detection"
HF_FILENAMES = {
    "V2": "swin_classifier_stage1_v2_epoch_3.pth",
    "V4": "swin_classifier_stage1_v4.pth",
    # "V5(underfitting)": "swin_classifier_stage1_v5_fp16.pth",
    "V7": "swin_classifier_4class_fp16_v7.pth",
    "V8": "swin_classifier_4class_fp16_v8_epoch7_acc9740.pth",
    "V9": "swin_classifier_4class_fp16_v9_acc9861.pth",
    "V1-CAFormer": "caformer_b36_4class.safetensors",
    "V2-CAFormer": "caformer_b36_4class_95.safetensors",
    "V2.5-CAFormer": "caformer_b36_4class_96.safetensors",
    "DINOv2-4class": "dinov2_4class.safetensors",
    # New DINOv2 checkpoint filename
    "DINOv2-MeanPool-Contrastive": "dinov2-base-4class-contrastive_epoch4.safetensors",
}
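# 4-class label order (inferred from the class names below; "ani" = anime):
#   non_ai     -> photo, not AI-generated
#   ai         -> photo, AI-generated
#   ani_non_ai -> anime, not AI-generated
#   ani_ai     -> anime, AI-generated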
CKPT_META = {
    "V2":  {"n_cls": 2, "head": "v4", "backbone": "swin_large_patch4_window12_384",
            "names": ["Non-AI Generated", "AI Generated"]},
    "V4":  {"n_cls": 2, "head": "v4", "backbone": "swin_large_patch4_window12_384",
            "names": ["Non-AI Generated", "AI Generated"]},
    # "V5(underfitting)": {"n_cls": 2, "head": "v5", "backbone": "swin_large_patch4_window12_384",
    #                      "names": ["Non-AI Generated", "AI Generated"]},
    "V7":  {"n_cls": 4, "head": "v7", "backbone": "swin_large_patch4_window12_384",
            "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
    "V8":  {"n_cls": 4, "head": "v7", "backbone": "swin_large_patch4_window12_384",
            "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
    "V9":  {"n_cls": 4, "head": "v7", "backbone": "swin_large_patch4_window12_384",
            "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
    "V1-CAFormer":   {"n_cls": 4, "head": "v7", "backbone": "caformer_b36.sail_in22k_ft_in1k_384",
                      "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
    "V2-CAFormer":   {"n_cls": 4, "head": "v7", "backbone": "caformer_b36.sail_in22k_ft_in1k_384",
                      "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
    "V2.5-CAFormer": {"n_cls": 4, "head": "v7", "backbone": "caformer_b36.sail_in22k_ft_in1k_384",
                      "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
    # Original DINOv2 metadata, with an explicit model_type
    "DINOv2-4class": {
        "model_type": "dinov2_weighted_pool",
        "backbone": "facebook/dinov2-base",
        "n_cls": 4,
        "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"],
    },
    # New DINOv2 model metadata
    "DINOv2-MeanPool-Contrastive": {
        "model_type": "dinov2_mean_pool",
        "backbone": "facebook/dinov2-base",
        "n_cls": 4,
        "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"],
    },
}
DEFAULT_CKPT = "V1-CAFormer"
LOCAL_CKPT_DIR = "./checkpoints"
SEED = 4421
DROP_RATE = 0.1
DROPOUT_RATE = 0.1  # From train.py, used by the DINOv2 heads
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(SEED); np.random.seed(SEED)
print(f"Using device: {device}")
model, current_ckpt = None, None
current_meta = None
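# These globals act as a one-slot cache: load_model() below swaps them in place
# whenever a different checkpoint is selected in the UI.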
# --------------------------------------------------
# 2. DINOv2 Classifiers
# --------------------------------------------------
# --- Original DINOv2 Classifier (Weighted Attention Pooling) ---
class DINOv2Classifier_WeightedPool(nn.Module):
    def __init__(self, model_name, num_classes):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_name)
        self.weight_self_attention = nn.MultiheadAttention(
            embed_dim=self.backbone.config.hidden_size,
            num_heads=self.backbone.config.num_attention_heads,
            dropout=self.backbone.config.hidden_dropout_prob,
            batch_first=True
        )
        self.weight_mlp = nn.Sequential(
            nn.Linear(self.backbone.config.hidden_size, self.backbone.config.hidden_size * 4),
            nn.LayerNorm(self.backbone.config.hidden_size * 4),
            nn.GELU(),
            nn.Linear(self.backbone.config.hidden_size * 4, 1)
        )
        self.classifier = nn.Sequential(
            nn.Dropout(DROPOUT_RATE),
            nn.Linear(self.backbone.config.hidden_size, self.backbone.config.hidden_size),
            nn.LayerNorm(self.backbone.config.hidden_size),
            nn.GELU(),
            nn.Dropout(DROPOUT_RATE),
            nn.Linear(self.backbone.config.hidden_size, num_classes)
        )
        nn.init.xavier_uniform_(self.weight_self_attention.in_proj_weight)
        nn.init.xavier_uniform_(self.weight_self_attention.out_proj.weight)
        nn.init.constant_(self.weight_self_attention.out_proj.bias, 0)
        # Xavier-initialize the Linear layers inside each Sequential head
        for seq in [self.weight_mlp, self.classifier]:
            for module in seq:
                if isinstance(module, nn.Linear):
                    nn.init.xavier_uniform_(module.weight)
                    nn.init.constant_(module.bias, 0)
    def forward(self, x):
        outputs = self.backbone(x)
        attn_output, _ = self.weight_self_attention(
            outputs.last_hidden_state,
            outputs.last_hidden_state,
            outputs.last_hidden_state,
        )
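        # Score each attended token with the MLP, softmax the scores over the token
        # dimension to get pooling weights, then take a weighted sum of the raw
        # hidden states as the pooled image representation.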
        raw_weights = self.weight_mlp(attn_output)
        raw_weights = raw_weights.squeeze(-1)
        pooling_weights = torch.softmax(raw_weights, dim=-1)
        pooled_output = torch.sum(outputs.last_hidden_state * pooling_weights.unsqueeze(-1), dim=1)
        return self.classifier(pooled_output)
# --- New DINOv2 Classifier (Mean Pooling) ---
class DINOv2Classifier_MeanPool(nn.Module):
    def __init__(self, model_name, num_classes):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_name)
        self.classifier = nn.Sequential(
            nn.Dropout(DROPOUT_RATE),
            nn.Linear(self.backbone.config.hidden_size, self.backbone.config.hidden_size),
            nn.LayerNorm(self.backbone.config.hidden_size),
            nn.GELU(),
            nn.Dropout(DROPOUT_RATE),
            nn.Linear(self.backbone.config.hidden_size, num_classes)
        )
        for module in self.classifier:
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                nn.init.constant_(module.bias, 0)
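    # return_features=True exposes the mean-pooled embedding instead of logits;
    # presumably this is the hook the contrastive training objective used.
    # Inference in this app only needs the logits.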
    def forward(self, x, return_features=False):
        outputs = self.backbone(x)
        pooled_output = outputs.last_hidden_state.mean(dim=1)
        if return_features:
            return pooled_output
        return self.classifier(pooled_output)
# --------------------------------------------------
# 3. Swin / CAFormer Classifier
# --------------------------------------------------
class SwinClassifier(nn.Module):
    def __init__(self, model_name, num_classes, pretrained=True,
                 head_version="v4"):
        super().__init__()
        self.backbone = timm.create_model(
            model_name, pretrained=pretrained, num_classes=0
        )
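        # num_classes=0 makes timm return pooled backbone features
        # (self.backbone.num_features dims) instead of classification logits.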
        self.data_config = timm.data.resolve_data_config({}, model=self.backbone)
        # ------- Choose the head architecture by checkpoint version -------
        if head_version == "v7":  # V7 / V8 / V9 and the CAFormer checkpoints: minimal 64-unit hidden layer, GELU
            self.classifier = nn.Sequential(
                nn.Dropout(DROP_RATE),
                nn.Linear(self.backbone.num_features, 64),
                nn.BatchNorm1d(64),
                nn.GELU(),
                nn.Dropout(DROP_RATE * 0.8),
                nn.Linear(64, num_classes),
            )
        elif head_version == "v5":  # V5: 512-128, GELU
            self.classifier = nn.Sequential(
                nn.Dropout(DROP_RATE),
                nn.Linear(self.backbone.num_features, 512),
                nn.BatchNorm1d(512),
                nn.GELU(),
                nn.Dropout(DROP_RATE * 0.7),
                nn.Linear(512, 128),
                nn.BatchNorm1d(128),
                nn.GELU(),
                nn.Dropout(DROP_RATE * 0.5),
                nn.Linear(128, num_classes),
            )
        else:  # V2 / V4: 512-128, ReLU
            self.classifier = nn.Sequential(
                nn.Dropout(DROP_RATE),
                nn.Linear(self.backbone.num_features, 512),
                nn.BatchNorm1d(512),
                nn.ReLU(),
                nn.Dropout(DROP_RATE * 0.7),
                nn.Linear(512, 128),
                nn.BatchNorm1d(128),
                nn.ReLU(),
                nn.Dropout(DROP_RATE * 0.5),
                nn.Linear(128, num_classes),
            )

    def forward(self, x):
        return self.classifier(self.backbone(x))
# --------------------------------------------------
# 4. Dynamic model loading
# --------------------------------------------------
def load_model(ckpt_name: str):
    """Load model only when `ckpt_name` changes."""
    global model, current_ckpt, current_meta
    if ckpt_name == current_ckpt and model is not None:
        return
    print(f"\n🔄 Switching to {ckpt_name} ...")
    meta = CKPT_META[ckpt_name]
    ckpt_filename = HF_FILENAMES[ckpt_name]
    ckpt_file = hf_hub_download(
        repo_id=REPO_ID,
        filename=ckpt_filename,
        local_dir=LOCAL_CKPT_DIR, force_download=False
    )
    print(f"Checkpoint: {ckpt_file}")
    # Build the model structure based on model_type
    model_type = meta.get("model_type")
    if model_type == "dinov2_weighted_pool":
        model = DINOv2Classifier_WeightedPool(
            model_name=meta["backbone"],
            num_classes=meta["n_cls"]
        ).to(device)
    elif model_type == "dinov2_mean_pool":
        model = DINOv2Classifier_MeanPool(
            model_name=meta["backbone"],
            num_classes=meta["n_cls"]
        ).to(device)
    else:  # Existing logic for Swin/CAFormer
        model = SwinClassifier(
            meta["backbone"],
            num_classes=meta["n_cls"],
            pretrained=False,
            head_version=meta.get("head", "v4")
        ).to(device)
    # Compatible load for .pth and .safetensors
    if ckpt_filename.endswith(".safetensors"):
        state = load_file(ckpt_file, device=device)
    else:
        state = torch.load(ckpt_file, map_location=device, weights_only=False)
    model.load_state_dict(state.get("model_state_dict", state), strict=True)
    model.eval()
    current_ckpt, current_meta = ckpt_name, meta
    print(f"✅ {ckpt_name} loaded (classes = {meta['n_cls']}).")
# --------------------------------------------------
# 5. Transform factory
# --------------------------------------------------
def build_transform(is_training: bool, interpolation: str):
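    # model.data_config holds the backbone's preprocessing defaults resolved by timm
    # (input_size, mean/std, crop_pct, ...); only the interpolation is overridden here.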
    if model is None:
        raise RuntimeError("Model not loaded yet.")
    cfg = model.data_config.copy()
    cfg.update(dict(interpolation=interpolation))
    return timm.data.create_transform(**cfg, is_training=is_training)
# --------------------------------------------------
# 6. Inference
# --------------------------------------------------
def predict(image: Image.Image,
            ckpt_name: str,
            interpolation: str = "bicubic"):
    if image is None:
        return None
    load_model(ckpt_name)
    image = image.convert("RGB")  # guard against RGBA / grayscale uploads
    # Select the transform based on the current model type
    if "dinov2" in current_meta.get("model_type", ""):
        # DINOv2-specific transform
        tfm = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    else:
        # Original transform logic for Swin/CAFormer
        tfm = build_transform(False, interpolation)
    inp = tfm(image).unsqueeze(0).to(device)
    with torch.no_grad():
        probs = F.softmax(model(inp), dim=1)[0].cpu()
    class_names = current_meta["names"]
    # Return a dict so gr.Label renders correctly for both 2-class and 4-class models
    return {class_names[i]: float(probs[i])
            for i in range(len(class_names))}
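# Example (illustrative only — the file name is hypothetical and the scores are placeholders):
#   scores = predict(Image.open("example.png"), "V9")
#   # -> {"non_ai": 0.03, "ai": 0.94, "ani_non_ai": 0.01, "ani_ai": 0.02}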
# --------------------------------------------------
# 7. Gradio UI
# --------------------------------------------------
def launch():
    load_model(DEFAULT_CKPT)  # preload the default checkpoint
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# AI Detector")
        gr.Markdown(
            "Choose a model checkpoint on the left, upload an image, "
            "and click **Run** to see predictions. Checkpoints V7 and later, and all DINOv2 models, output 4 classes."
        )
        with gr.Row():
            with gr.Column(scale=1):
                run_btn = gr.Button("🚀 Run", variant="primary")
                sel_ckpt = gr.Dropdown(
                    list(HF_FILENAMES.keys()),
                    value=DEFAULT_CKPT, label="Checkpoint"
                )
                sel_interp = gr.Radio(
                    ["bilinear", "bicubic", "nearest"],
                    value="bicubic", label="Resize Interpolation (for Swin/CAFormer)"
                )
                in_img = gr.Image(type="pil", label="Upload Image")
            with gr.Column(scale=1):
                # num_top_classes is set to 4 so the label widget works for both 2-class and 4-class models
                out_lbl = gr.Label(num_top_classes=4, label="Predictions")
        run_btn.click(
            predict,
            inputs=[in_img, sel_ckpt, sel_interp],
            outputs=[out_lbl]
        )
        # optional example folder
        if not os.path.exists("examples"):
            os.makedirs("examples")
            print("Put some jpg/png files inside ./examples to populate the demo examples.")
        example_files = [os.path.join("examples", f)
                         for f in os.listdir("examples")
                         if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
        if example_files:
            gr.Examples(
                examples=[[f, DEFAULT_CKPT, "bicubic"] for f in example_files],
                inputs=[in_img, sel_ckpt, sel_interp],
                outputs=[out_lbl],
                fn=predict,
                cache_examples=False,
            )
    demo.launch()

# --------------------------------------------------
if __name__ == "__main__":
    launch()