# -*- coding: utf-8 -*-
"""
Swin/CAFormer/DINOv2 AI detection
-------------------------------------------------------------------
• Swin-V2 / V4                : 2-class (AI vs. Non-AI)
• Swin-V7 / V8 / V9           : 4-class (photo / anime × AI / Non-AI)
• CAFormer V1 / V2 / V2.5     : 4-class (photo / anime × AI / Non-AI)
• DINOv2-4class               : 4-class (photo / anime × AI / Non-AI)
• DINOv2-MeanPool-Contrastive : 4-class (photo / anime × AI / Non-AI)
-------------------------------------------------------------------
Author: telecomadm1145
"""
import os, torch, timm, numpy as np
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file   # .safetensors checkpoint loading
from transformers import AutoModel        # DINOv2 backbone
from torchvision import transforms

# --------------------------------------------------
# 1. Model & checkpoint metadata
# --------------------------------------------------
REPO_ID = "telecomadm1145/swin-ai-detection"

HF_FILENAMES = {
    "V2": "swin_classifier_stage1_v2_epoch_3.pth",
    "V4": "swin_classifier_stage1_v4.pth",
    # "V5(underfitting)": "swin_classifier_stage1_v5_fp16.pth",
    "V7": "swin_classifier_4class_fp16_v7.pth",
    "V8": "swin_classifier_4class_fp16_v8_epoch7_acc9740.pth",
    "V9": "swin_classifier_4class_fp16_v9_acc9861.pth",
    "V1-CAFormer": "caformer_b36_4class.safetensors",
    "V2-CAFormer": "caformer_b36_4class_95.safetensors",
    "V2.5-CAFormer": "caformer_b36_4class_96.safetensors",
    "DINOv2-4class": "dinov2_4class.safetensors",
    "DINOv2-MeanPool-Contrastive": "dinov2-base-4class-contrastive_epoch4.safetensors",
}

CKPT_META = {
    "V2":   {"n_cls": 2, "head": "v4", "backbone": "swin_large_patch4_window12_384",
             "names": ["Non-AI Generated", "AI Generated"]},
    "V4":   {"n_cls": 2, "head": "v4", "backbone": "swin_large_patch4_window12_384",
             "names": ["Non-AI Generated", "AI Generated"]},
    # "V5(underfitting)": {"n_cls": 2, "head": "v5", "backbone": "swin_large_patch4_window12_384",
    #                      "names": ["Non-AI Generated", "AI Generated"]},
    "V7":   {"n_cls": 4, "head": "v7", "backbone": "swin_large_patch4_window12_384",
             "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
    "V8":   {"n_cls": 4, "head": "v7", "backbone": "swin_large_patch4_window12_384",
             "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
    "V9":   {"n_cls": 4, "head": "v7", "backbone": "swin_large_patch4_window12_384",
             "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
    "V1-CAFormer":   {"n_cls": 4, "head": "v7", "backbone": "caformer_b36.sail_in22k_ft_in1k_384",
                      "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
    "V2-CAFormer":   {"n_cls": 4, "head": "v7", "backbone": "caformer_b36.sail_in22k_ft_in1k_384",
                      "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
    "V2.5-CAFormer": {"n_cls": 4, "head": "v7", "backbone": "caformer_b36.sail_in22k_ft_in1k_384",
                      "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
    # DINOv2 entries carry a "model_type" key so load_model() can build the right class.
    "DINOv2-4class": {"model_type": "dinov2_weighted_pool", "backbone": "facebook/dinov2-base",
                      "n_cls": 4, "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
    "DINOv2-MeanPool-Contrastive": {"model_type": "dinov2_mean_pool", "backbone": "facebook/dinov2-base",
                                    "n_cls": 4, "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
}
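# A new checkpoint only needs matching entries in the two dicts above. A
# hypothetical sketch (the key and filename below are made up, not real
# files in the repo):
#
#   HF_FILENAMES["V10-Swin"] = "swin_classifier_4class_v10.pth"
#   CKPT_META["V10-Swin"] = {"n_cls": 4, "head": "v7",
#                            "backbone": "swin_large_patch4_window12_384",
#                            "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]}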
DEFAULT_CKPT = "V1-CAFormer"
LOCAL_CKPT_DIR = "./checkpoints"
SEED = 4421
DROP_RATE = 0.1
DROPOUT_RATE = 0.1  # from train.py, used by the DINOv2 heads

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(SEED)
np.random.seed(SEED)
print(f"Using device: {device}")

model, current_ckpt, current_meta = None, None, None

# --------------------------------------------------
# 2. DINOv2 classifiers
# --------------------------------------------------
class DINOv2Classifier_WeightedPool(nn.Module):
    """DINOv2 backbone with weighted attention pooling over all tokens."""

    def __init__(self, model_name, num_classes):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_name)
        hidden = self.backbone.config.hidden_size
        self.weight_self_attention = nn.MultiheadAttention(
            embed_dim=hidden,
            num_heads=self.backbone.config.num_attention_heads,
            dropout=self.backbone.config.hidden_dropout_prob,
            batch_first=True,
        )
        self.weight_mlp = nn.Sequential(
            nn.Linear(hidden, hidden * 4),
            nn.LayerNorm(hidden * 4),
            nn.GELU(),
            nn.Linear(hidden * 4, 1),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(DROPOUT_RATE),
            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
            nn.Dropout(DROPOUT_RATE),
            nn.Linear(hidden, num_classes),
        )
        nn.init.xavier_uniform_(self.weight_self_attention.in_proj_weight)
        nn.init.xavier_uniform_(self.weight_self_attention.out_proj.weight)
        nn.init.constant_(self.weight_self_attention.out_proj.bias, 0)
        # Initialize every Linear layer inside the pooling MLP and classifier.
        for seq in (self.weight_mlp, self.classifier):
            for module in seq:
                if isinstance(module, nn.Linear):
                    nn.init.xavier_uniform_(module.weight)
                    nn.init.constant_(module.bias, 0)

    def forward(self, x):
        outputs = self.backbone(x)
        # Self-attention over all tokens, then an MLP scores each token;
        # softmax turns the scores into pooling weights.
        attn_output, _ = self.weight_self_attention(
            outputs.last_hidden_state,
            outputs.last_hidden_state,
            outputs.last_hidden_state,
        )
        raw_weights = self.weight_mlp(attn_output).squeeze(-1)
        pooling_weights = torch.softmax(raw_weights, dim=-1)
        pooled_output = torch.sum(
            outputs.last_hidden_state * pooling_weights.unsqueeze(-1), dim=1
        )
        return self.classifier(pooled_output)


class DINOv2Classifier_MeanPool(nn.Module):
    """DINOv2 backbone with plain mean pooling over all tokens."""

    def __init__(self, model_name, num_classes):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_name)
        hidden = self.backbone.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Dropout(DROPOUT_RATE),
            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
            nn.Dropout(DROPOUT_RATE),
            nn.Linear(hidden, num_classes),
        )
        for module in self.classifier:
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                nn.init.constant_(module.bias, 0)

    def forward(self, x, return_features=False):
        outputs = self.backbone(x)
        pooled_output = outputs.last_hidden_state.mean(dim=1)
        if return_features:  # raw features (e.g., for contrastive training)
            return pooled_output
        return self.classifier(pooled_output)
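# A minimal, self-contained sketch contrasting the two pooling strategies
# above on dummy tensors. Sizes are illustrative: dinov2-base at 224 px
# yields 1 CLS + 256 patch tokens = 257 tokens of width 768. Defined for
# reference only; nothing in the app calls it.
def _pooling_sketch():
    B, N, H = 1, 257, 768
    tokens = torch.randn(B, N, H)                     # stands in for last_hidden_state
    w = torch.softmax(torch.randn(B, N), dim=-1)      # per-token weights, sum to 1
    weighted = (tokens * w.unsqueeze(-1)).sum(dim=1)  # (B, H) weighted pool
    mean = tokens.mean(dim=1)                         # (B, H) mean pool
    return weighted.shape, mean.shape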
# --------------------------------------------------
# 3. Swin/CAFormer classifier
# --------------------------------------------------
class SwinClassifier(nn.Module):
    def __init__(self, model_name, num_classes, pretrained=True, head_version="v4"):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        self.data_config = timm.data.resolve_data_config({}, model=self.backbone)

        # ------- pick the head by checkpoint version -------
        if head_version == "v7":
            # V7 / V8 / V9 and the CAFormer checkpoints: minimal 64-hidden head, GELU
            self.classifier = nn.Sequential(
                nn.Dropout(DROP_RATE),
                nn.Linear(self.backbone.num_features, 64),
                nn.BatchNorm1d(64),
                nn.GELU(),
                nn.Dropout(DROP_RATE * 0.8),
                nn.Linear(64, num_classes),
            )
        elif head_version == "v5":
            # V5: 512-128 head, GELU
            self.classifier = nn.Sequential(
                nn.Dropout(DROP_RATE),
                nn.Linear(self.backbone.num_features, 512),
                nn.BatchNorm1d(512),
                nn.GELU(),
                nn.Dropout(DROP_RATE * 0.7),
                nn.Linear(512, 128),
                nn.BatchNorm1d(128),
                nn.GELU(),
                nn.Dropout(DROP_RATE * 0.5),
                nn.Linear(128, num_classes),
            )
        else:
            # V2 / V4: 512-128 head, ReLU
            self.classifier = nn.Sequential(
                nn.Dropout(DROP_RATE),
                nn.Linear(self.backbone.num_features, 512),
                nn.BatchNorm1d(512),
                nn.ReLU(),
                nn.Dropout(DROP_RATE * 0.7),
                nn.Linear(512, 128),
                nn.BatchNorm1d(128),
                nn.ReLU(),
                nn.Dropout(DROP_RATE * 0.5),
                nn.Linear(128, num_classes),
            )

    def forward(self, x):
        return self.classifier(self.backbone(x))


# --------------------------------------------------
# 4. Dynamic model loading
# --------------------------------------------------
def load_model(ckpt_name: str):
    """Load a model only when `ckpt_name` changes."""
    global model, current_ckpt, current_meta
    if ckpt_name == current_ckpt and model is not None:
        return
    print(f"\n🔄 Switching to {ckpt_name} ...")
    meta = CKPT_META[ckpt_name]
    ckpt_filename = HF_FILENAMES[ckpt_name]
    ckpt_file = hf_hub_download(
        repo_id=REPO_ID,
        filename=ckpt_filename,
        local_dir=LOCAL_CKPT_DIR,
        force_download=False,
    )
    print(f"Checkpoint: {ckpt_file}")

    # Build the model structure based on model_type
    model_type = meta.get("model_type")
    if model_type == "dinov2_weighted_pool":
        model = DINOv2Classifier_WeightedPool(
            model_name=meta["backbone"], num_classes=meta["n_cls"]
        ).to(device)
    elif model_type == "dinov2_mean_pool":
        model = DINOv2Classifier_MeanPool(
            model_name=meta["backbone"], num_classes=meta["n_cls"]
        ).to(device)
    else:  # Swin / CAFormer
        model = SwinClassifier(
            meta["backbone"],
            num_classes=meta["n_cls"],
            pretrained=False,
            head_version=meta.get("head", "v4"),
        ).to(device)

    # Load either .pth or .safetensors checkpoints
    if ckpt_filename.endswith(".safetensors"):
        state = load_file(ckpt_file, device=device)
    else:
        state = torch.load(ckpt_file, map_location=device, weights_only=False)
    model.load_state_dict(state.get("model_state_dict", state), strict=True)
    model.eval()
    current_ckpt, current_meta = ckpt_name, meta
    print(f"✅ {ckpt_name} loaded (classes = {meta['n_cls']}).")


# --------------------------------------------------
# 5. Transform factory
# --------------------------------------------------
def build_transform(is_training: bool, interpolation: str):
    if model is None:
        raise RuntimeError("Model not loaded yet.")
    cfg = model.data_config.copy()
    cfg.update(dict(interpolation=interpolation))
    return timm.data.create_transform(**cfg, is_training=is_training)


# --------------------------------------------------
# 6. Inference
# --------------------------------------------------
@torch.no_grad()
def predict(image: Image.Image, ckpt_name: str, interpolation: str = "bicubic"):
    if image is None:
        return None
    load_model(ckpt_name)
    image = image.convert("RGB")  # guard against RGBA/grayscale uploads

    # Pick the transform for the current model type
    if "dinov2" in current_meta.get("model_type", ""):
        # DINOv2 models expect a fixed 224×224 ImageNet-normalized input
        tfm = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
    else:
        # Swin / CAFormer use the timm transform resolved from the backbone
        tfm = build_transform(False, interpolation)

    inp = tfm(image).unsqueeze(0).to(device)
    probs = F.softmax(model(inp), dim=1)[0].cpu()
    class_names = current_meta["names"]
    # Return a dict so gr.Label renders correctly for both 2- and 4-class models
    return {class_names[i]: float(probs[i]) for i in range(len(class_names))}
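# Minimal programmatic use of the inference path, bypassing the UI (the
# image path below is hypothetical):
#
#   img = Image.open("examples/sample.jpg")
#   scores = predict(img, "V1-CAFormer")    # {"non_ai": ..., "ai": ..., ...}
#   print(max(scores, key=scores.get))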
# --------------------------------------------------
# 7. Gradio UI
# --------------------------------------------------
def launch():
    load_model(DEFAULT_CKPT)  # preload the default checkpoint

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# AI Detector")
        gr.Markdown(
            "Choose a model checkpoint on the left, upload an image, "
            "and click **Run** to see predictions. Checkpoint V7+ and all "
            "DINOv2 models output 4 classes."
        )
        with gr.Row():
            with gr.Column(scale=1):
                run_btn = gr.Button("🚀 Run", variant="primary")
                sel_ckpt = gr.Dropdown(
                    list(HF_FILENAMES.keys()), value=DEFAULT_CKPT, label="Checkpoint"
                )
                sel_interp = gr.Radio(
                    ["bilinear", "bicubic", "nearest"],
                    value="bicubic",
                    label="Resize Interpolation (for Swin/CAFormer)",
                )
                in_img = gr.Image(type="pil", label="Upload Image")
            with gr.Column(scale=1):
                # num_top_classes=4 works for both 2-class and 4-class checkpoints
                out_lbl = gr.Label(num_top_classes=4, label="Predictions")

        run_btn.click(predict, inputs=[in_img, sel_ckpt, sel_interp], outputs=[out_lbl])

        # Optional example folder
        if not os.path.exists("examples"):
            os.makedirs("examples")
            print("Put some jpg/png files inside ./examples for demo examples")
        example_files = [
            os.path.join("examples", f)
            for f in os.listdir("examples")
            if f.lower().endswith((".png", ".jpg", ".jpeg"))
        ]
        if example_files:
            gr.Examples(
                examples=[[f, DEFAULT_CKPT, "bicubic"] for f in example_files],
                inputs=[in_img, sel_ckpt, sel_interp],
                outputs=[out_lbl],
                fn=predict,
                cache_examples=False,
            )

    demo.launch()


# --------------------------------------------------
if __name__ == "__main__":
    launch()