File size: 3,915 Bytes
eda917f
 
 
 
 
 
 
 
95f585b
eda917f
 
 
 
95f585b
eda917f
 
 
95f585b
 
 
 
 
 
 
 
 
 
 
 
 
eda917f
 
95f585b
eda917f
 
 
 
 
 
 
 
 
 
 
95f585b
eda917f
 
 
 
 
 
95f585b
 
 
 
 
 
 
 
 
 
 
eda917f
 
 
95f585b
eda917f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""
PyTorch model implementation for AI image detection
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import numpy as np
import os

# Define the model architecture based on EfficientNetV2-S
class AIDetectorModel(nn.Module):
    """EfficientNetV2-S backbone with a custom two-way classification head.

    The backbone is constructed with ``weights=None``, so no pretrained
    weights are downloaded here; trained weights must be loaded separately
    (e.g. via ``load_state_dict``). The stock torchvision classifier is
    replaced by a 1024 -> 512 -> 2 MLP with ReLU activations and dropout.
    """

    def __init__(self):
        super().__init__()
        backbone = models.efficientnet_v2_s(weights=None)

        # Input width of torchvision's stock final Linear layer, i.e. the
        # size of the pooled feature vector the new head will consume.
        feature_dim = backbone.classifier[1].in_features

        # NOTE: keep this exact layer order. The Sequential indices (Linear
        # layers at 0, 3 and 6) become the state_dict keys that saved
        # checkpoints are matched against.
        head_layers = [
            nn.Linear(feature_dim, 1024),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(512, 2),  # 2 classes: real or AI-generated
        ]
        backbone.classifier = nn.Sequential(*head_layers)
        self.base_model = backbone

    def forward(self, x):
        """Run the backbone and head on ``x`` and return raw logits.

        ``x`` is presumably a normalized (N, 3, H, W) image batch -- TODO
        confirm against callers. The output has 2 columns (one per class)
        and no softmax applied.
        """
        logits = self.base_model(x)
        return logits

class PyTorchAIDetector:
    """Single-image inference wrapper around a trained ``AIDetectorModel``.

    Weights are loaded once in ``__init__``; images are then scored with
    :meth:`analyze_image`.
    """

    def __init__(self, model_path='best_model_improved.pth'):
        """
        Initialize the PyTorch-based AI image detector

        Args:
            model_path: Path to the trained model file. It is joined onto the
                directory containing this module, so relative paths are
                resolved there while absolute paths are used unchanged
                (``os.path.join`` semantics).

        Raises:
            Any error raised by ``torch.load`` (e.g. FileNotFoundError) if the
            checkpoint cannot be read, and RuntimeError if even non-strict
            loading fails (e.g. a tensor shape mismatch).
        """
        # Check if CUDA is available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Initialize the model
        self.model = AIDetectorModel()

        # Load the trained weights
        model_path = os.path.join(os.path.dirname(__file__), model_path)
        self._load_weights(model_path)

        self.model.to(self.device)
        self.model.eval()  # Set to evaluation mode (disables dropout)

        # Define image transformations - same as used in training
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def _load_weights(self, model_path):
        """Load the checkpoint at ``model_path`` into ``self.model``.

        Strict loading is tried first. If the keys do not match, loading is
        retried with ``strict=False``, and any missing or unexpected keys are
        reported. Otherwise a checkpoint that matches nothing would silently
        leave the model with its freshly initialized weights.
        """
        # NOTE(review): torch.load unpickles the file, so only load
        # checkpoints from trusted sources. Consider weights_only=True
        # (torch >= 1.13) if the checkpoint is a plain state_dict.
        # The checkpoint is read once and reused for the fallback attempt.
        state_dict = torch.load(model_path, map_location=self.device)
        try:
            # Try to load with strict=True first
            self.model.load_state_dict(state_dict)
            print(f"Model loaded successfully from {model_path}")
        except RuntimeError as e:
            # load_state_dict raises RuntimeError on key/shape mismatches.
            print(f"Error with strict loading: {e}")
            print("Trying with strict=False...")
            incompatible = self.model.load_state_dict(state_dict, strict=False)
            print("Model loaded with strict=False")
            if incompatible.missing_keys:
                print(
                    f"Warning: {len(incompatible.missing_keys)} model keys were not "
                    f"found in the checkpoint and keep their initial values: "
                    f"{incompatible.missing_keys}"
                )
            if incompatible.unexpected_keys:
                print(
                    f"Warning: {len(incompatible.unexpected_keys)} checkpoint keys "
                    f"were ignored: {incompatible.unexpected_keys}"
                )

    def analyze_image(self, image_path, threshold=0.5):
        """
        Analyze an image to detect if it's AI-generated

        Args:
            image_path: Path to the image
            threshold: Scores strictly greater than this are flagged as
                AI-generated. Defaults to 0.5, the previous hard-coded cutoff.

        Returns:
            Dictionary with analysis results:
              - ``image_path``: the path passed in
              - ``overall_score``: softmax probability of class 1 (assumed to
                be the AI-generated class)
              - ``is_ai_generated``: ``overall_score > threshold``
              - ``model_type``: always ``"pytorch"``

        Raises:
            ValueError: if the image cannot be opened/decoded or inference
                fails; the original exception is chained as ``__cause__``.
        """
        try:
            # Load and preprocess the image. The context manager closes the
            # underlying file handle once the RGB copy has been made.
            with Image.open(image_path) as img:
                image = img.convert('RGB')
            image_tensor = self.transform(image).unsqueeze(0).to(self.device)

            # Make prediction
            with torch.no_grad():
                outputs = self.model(image_tensor)
                probabilities = F.softmax(outputs, dim=1)

            # Get the probability of being AI-generated (assuming class 1 is AI-generated)
            ai_score = probabilities[0, 1].item()
        except Exception as e:
            raise ValueError(f"Failed to analyze image: {str(e)}") from e

        # Prepare results
        return {
            "image_path": image_path,
            "overall_score": float(ai_score),
            "is_ai_generated": bool(ai_score > threshold),
            "model_type": "pytorch"
        }