# deepfake-detector / pytorch_model.py
# Author: yaya36095
# Last commit: 95f585b (verified) - "Update pytorch_model.py"
"""
PyTorch model implementation for AI image detection
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import numpy as np
import os
# Define the model architecture based on EfficientNetV2-S
class AIDetectorModel(nn.Module):
    """Two-logit real-vs-AI image classifier on an EfficientNetV2-S backbone.

    The torchvision backbone is built without pretrained weights
    (``weights=None``). Its stock classifier head is replaced by a
    three-layer MLP (1024 -> 512 -> 2) with ReLU activations and dropout
    p=0.3 between layers. The network is registered as ``self.base_model``,
    so state-dict keys take the form ``base_model.*``.
    """

    def __init__(self):
        super().__init__()
        backbone = models.efficientnet_v2_s(weights=None)
        # Input width of the stock head's Linear layer, which is entry 1
        # of the torchvision classifier Sequential.
        feature_dim = backbone.classifier[1].in_features
        head_layers = [
            nn.Linear(feature_dim, 1024),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(512, 2),  # 2 logits: real vs. AI-generated
        ]
        backbone.classifier = nn.Sequential(*head_layers)
        self.base_model = backbone

    def forward(self, x):
        """Return raw (un-normalised) logits for the batch ``x``."""
        logits = self.base_model(x)
        return logits
class PyTorchAIDetector:
    """Single-image inference wrapper around ``AIDetectorModel``.

    On construction it builds the network, loads trained weights from disk,
    moves the model to CUDA when available (CPU otherwise) and switches it
    to eval mode. ``analyze_image`` then scores one image at a time.
    """

    def __init__(self, model_path='best_model_improved.pth'):
        """
        Initialize the PyTorch-based AI image detector

        Args:
            model_path: Path to the trained model file (a state dict). A
                relative path is resolved against the directory containing
                this module, not the current working directory. An absolute
                path is used unchanged (``os.path.join`` semantics).

        Raises:
            FileNotFoundError (or other I/O / unpickling errors) from
                ``torch.load`` if the checkpoint cannot be read.
            RuntimeError: if the checkpoint cannot be applied even with
                ``strict=False`` (e.g. a tensor shape mismatch).
        """
        # Check if CUDA is available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Initialize the model
        self.model = AIDetectorModel()

        # Resolve the checkpoint path relative to this module's directory
        model_path = os.path.join(os.path.dirname(__file__), model_path)

        # Read the checkpoint exactly once, outside the try below. A missing
        # or unreadable file is not a key mismatch, and retrying with
        # strict=False cannot fix it.
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from a trusted source; on
        # torch >= 1.13 consider passing weights_only=True.
        state_dict = torch.load(model_path, map_location=self.device)

        try:
            # Try to load with strict=True first
            self.model.load_state_dict(state_dict)
            print(f"Model loaded successfully from {model_path}")
        except RuntimeError as e:
            # load_state_dict raises RuntimeError on missing/unexpected keys.
            # Shape mismatches also raise, and strict=False does not forgive
            # those.
            print(f"Error with strict loading: {e}")
            print("Trying with strict=False...")
            incompatible = self.model.load_state_dict(state_dict, strict=False)
            print("Model loaded with strict=False")
            # Make a partial load visible. Parameters listed as missing keep
            # their random initialization, which would otherwise silently
            # degrade every prediction.
            if incompatible.missing_keys:
                print(f"Warning: {len(incompatible.missing_keys)} missing key(s) "
                      f"left at random initialization: "
                      f"{incompatible.missing_keys}")
            if incompatible.unexpected_keys:
                print(f"Warning: {len(incompatible.unexpected_keys)} unexpected "
                      f"key(s) ignored: {incompatible.unexpected_keys}")

        self.model.to(self.device)
        self.model.eval()  # Evaluation mode: disables the Dropout layers

        # Preprocessing: resize to 224x224, convert to a [0, 1] tensor, then
        # normalise with the standard ImageNet mean/std. This is intended to
        # match the training pipeline (TODO confirm against training code).
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def analyze_image(self, image_path):
        """
        Analyze an image to detect if it's AI-generated

        Args:
            image_path: Path to the image (anything ``PIL.Image.open``
                accepts).

        Returns:
            Dictionary with analysis results:
              - "image_path": the argument, unchanged
              - "overall_score": softmax probability of output index 1,
                treated as the AI-generated class (assumption carried over
                from the original code - confirm against training labels)
              - "is_ai_generated": True when that probability exceeds 0.5
              - "model_type": always "pytorch"

        Raises:
            ValueError: if loading, preprocessing or inference fails. The
                underlying exception is chained as ``__cause__``.
        """
        try:
            # Close the file handle as soon as the pixels are decoded.
            # convert() returns a new, fully loaded image that remains
            # valid after the source image is closed.
            with Image.open(image_path) as img:
                image = img.convert('RGB')
            image_tensor = self.transform(image).unsqueeze(0).to(self.device)

            # Make prediction (no autograd bookkeeping needed for inference)
            with torch.no_grad():
                outputs = self.model(image_tensor)
                probabilities = F.softmax(outputs, dim=1)
                # Probability of index 1 (assumed AI-generated) for the
                # single image in the batch
                ai_score = probabilities[0, 1].item()
        except Exception as e:
            raise ValueError(f"Failed to analyze image: {str(e)}") from e

        is_ai_generated = ai_score > 0.5

        return {
            "image_path": image_path,
            "overall_score": float(ai_score),
            "is_ai_generated": bool(is_ai_generated),
            "model_type": "pytorch"
        }