|
""" |
|
PyTorch model implementation for AI image detection |
|
""" |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torchvision.transforms as transforms |
|
import torchvision.models as models |
|
from PIL import Image |
|
import numpy as np |
|
import os |
|
|
|
|
|
class AIDetectorModel(nn.Module):
    """EfficientNetV2-S backbone with a custom three-layer classification head.

    The backbone is constructed without pretrained weights (``weights=None``)
    and its stock ``classifier`` is replaced by an MLP head ending in a
    2-unit Linear layer, so ``forward`` returns raw (un-normalized) logits.
    """

    def __init__(self):
        super().__init__()

        backbone = models.efficientnet_v2_s(weights=None)

        # Input width of the stock classifier's Linear layer (index 1); the
        # replacement head must accept the same feature size.
        feature_dim = backbone.classifier[1].in_features

        head_layers = [
            nn.Linear(feature_dim, 1024),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(512, 2),
        ]
        backbone.classifier = nn.Sequential(*head_layers)

        self.base_model = backbone

    def forward(self, x):
        """Run the backbone plus custom head on *x* and return the logits."""
        logits = self.base_model(x)
        return logits
|
|
|
class PyTorchAIDetector:
    """Inference wrapper that loads trained ``AIDetectorModel`` weights and scores images."""

    def __init__(self, model_path='best_model_improved.pth'):
        """
        Initialize the PyTorch-based AI image detector.

        Args:
            model_path: Path to the trained state-dict file. A relative path is
                resolved against this module's directory (not the current
                working directory); an absolute path is used unchanged.

        Raises:
            Any error raised by ``torch.load`` (e.g. ``FileNotFoundError``)
            if the checkpoint cannot be read, and ``RuntimeError`` if its
            tensors cannot be loaded even non-strictly (e.g. shape mismatch).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        self.model = AIDetectorModel()

        # os.path.join discards the module directory when model_path is absolute.
        model_path = os.path.join(os.path.dirname(__file__), model_path)
        self._load_weights(model_path)

        self.model.to(self.device)
        self.model.eval()

        # Resize to 224x224 and normalize with the standard ImageNet mean/std.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def _load_weights(self, model_path):
        """Load the state dict at *model_path* into ``self.model``.

        The file is read exactly once. A strict load is tried first; if the
        keys do not match exactly, loading is retried with ``strict=False``
        and the missing/unexpected keys are reported.
        """
        # SECURITY NOTE(review): torch.load unpickles the file, which can run
        # arbitrary code. Only load trusted checkpoints; on torch >= 1.13
        # consider passing weights_only=True.
        state_dict = torch.load(model_path, map_location=self.device)

        try:
            self.model.load_state_dict(state_dict)
        except RuntimeError as e:
            # load_state_dict signals key/shape mismatches with RuntimeError.
            # Read errors from torch.load above are deliberately not caught here.
            print(f"Error with strict loading: {e}")
            print("Trying with strict=False...")
            result = self.model.load_state_dict(state_dict, strict=False)
            print("Model loaded with strict=False")
            if result.missing_keys:
                print(f"Missing keys (left at initial values): {result.missing_keys}")
            if result.unexpected_keys:
                print(f"Unexpected keys (ignored): {result.unexpected_keys}")
        else:
            print(f"Model loaded successfully from {model_path}")

    def analyze_image(self, image_path, *, threshold=0.5):
        """
        Analyze an image to detect if it's AI-generated.

        Args:
            image_path: Path to the image.
            threshold: Score strictly above which the image is flagged as
                AI-generated. Defaults to 0.5.

        Returns:
            Dictionary with keys "image_path", "overall_score" (softmax
            probability of output index 1), "is_ai_generated" and "model_type".

        Raises:
            ValueError: if the image cannot be opened, transformed or scored;
                the underlying exception is chained as ``__cause__``.
        """
        try:
            # The context manager releases the file handle right away instead
            # of waiting for garbage collection; convert() returns a loaded copy.
            with Image.open(image_path) as img:
                image = img.convert('RGB')
            image_tensor = self.transform(image).unsqueeze(0).to(self.device)

            with torch.no_grad():
                outputs = self.model(image_tensor)
                probabilities = F.softmax(outputs, dim=1)

            # NOTE(review): assumes output index 1 is the "AI-generated" class
            # used during training — confirm against the training labels.
            ai_score = probabilities[0, 1].item()
        except Exception as e:
            raise ValueError(f"Failed to analyze image: {str(e)}") from e

        return {
            "image_path": image_path,
            "overall_score": float(ai_score),
            "is_ai_generated": bool(ai_score > threshold),
            "model_type": "pytorch",
        }
|
|