""" PyTorch model implementation for AI image detection """ import torch import torch.nn as nn import torch.nn.functional as F import torchvision.transforms as transforms import torchvision.models as models from PIL import Image import numpy as np import os # Define the model architecture based on EfficientNetV2-S class AIDetectorModel(nn.Module): def __init__(self): super(AIDetectorModel, self).__init__() # Load EfficientNetV2-S as base model self.base_model = models.efficientnet_v2_s(weights=None) # Replace classifier with custom layers self.base_model.classifier = nn.Sequential( nn.Linear(self.base_model.classifier[1].in_features, 1024), nn.ReLU(), nn.Dropout(p=0.3), nn.Linear(1024, 512), nn.ReLU(), nn.Dropout(p=0.3), nn.Linear(512, 2) # 2 classes: real or AI-generated ) def forward(self, x): return self.base_model(x) class PyTorchAIDetector: def __init__(self, model_path='best_model_improved.pth'): """ Initialize the PyTorch-based AI image detector Args: model_path: Path to the trained model file """ # Check if CUDA is available self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {self.device}") # Initialize the model self.model = AIDetectorModel() # Load the trained weights model_path = os.path.join(os.path.dirname(__file__), model_path) try: # Try to load with strict=True first self.model.load_state_dict(torch.load(model_path, map_location=self.device)) print(f"Model loaded successfully from {model_path}") except Exception as e: print(f"Error with strict loading: {e}") print("Trying with strict=False...") # If that fails, try with strict=False self.model.load_state_dict(torch.load(model_path, map_location=self.device), strict=False) print("Model loaded with strict=False") self.model.to(self.device) self.model.eval() # Set to evaluation mode # Define image transformations - same as used in training self.transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) def analyze_image(self, image_path): """ Analyze an image to detect if it's AI-generated Args: image_path: Path to the image Returns: Dictionary with analysis results """ try: # Load and preprocess the image image = Image.open(image_path).convert('RGB') image_tensor = self.transform(image).unsqueeze(0).to(self.device) # Make prediction with torch.no_grad(): outputs = self.model(image_tensor) probabilities = F.softmax(outputs, dim=1) # Get the probability of being AI-generated (assuming class 1 is AI-generated) ai_score = probabilities[0, 1].item() # Determine if the image is AI-generated is_ai_generated = ai_score > 0.5 # Prepare results results = { "image_path": image_path, "overall_score": float(ai_score), "is_ai_generated": bool(is_ai_generated), "model_type": "pytorch" } return results except Exception as e: raise ValueError(f"Failed to analyze image: {str(e)}")