yaya36095 commited on
Commit
ce52ff3
·
verified ·
1 Parent(s): 205e3be

Upload 2 files

Browse files
Files changed (2) hide show
  1. model.py +62 -0
  2. preprocessor_config.json +10 -0
model.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchvision.models as models
5
+ from transformers import PreTrainedModel, AutoConfig
6
+
7
+ # Define the model architecture based on EfficientNetV2-S
8
+ class AIDetectorModel(nn.Module):
9
+ def __init__(self):
10
+ super(AIDetectorModel, self).__init__()
11
+ # Load EfficientNetV2-S as base model
12
+ self.base_model = models.efficientnet_v2_s(weights=None)
13
+
14
+ # Replace classifier with custom layers
15
+ self.base_model.classifier = nn.Sequential(
16
+ nn.Linear(self.base_model.classifier[1].in_features, 1024),
17
+ nn.ReLU(),
18
+ nn.Dropout(p=0.3),
19
+ nn.Linear(1024, 512),
20
+ nn.ReLU(),
21
+ nn.Dropout(p=0.3),
22
+ nn.Linear(512, 2) # 2 classes: real or AI-generated
23
+ )
24
+
25
+ def forward(self, x):
26
+ return self.base_model(x)
27
+
28
+ # Wrapper class to make the model compatible with Hugging Face
29
+ class AIDetectorForImageClassification(PreTrainedModel):
30
+ def __init__(self, config):
31
+ super().__init__(config)
32
+ self.num_labels = config.num_labels
33
+ self.model = AIDetectorModel()
34
+
35
+ # Load the trained weights
36
+ model_path = os.path.join(os.getcwd(), "best_model_improved.pth")
37
+ try:
38
+ # Try to load with strict=True first
39
+ self.model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
40
+ print(f"Model loaded successfully from {model_path}")
41
+ except Exception as e:
42
+ print(f"Error with strict loading: {e}")
43
+ print("Trying with strict=False...")
44
+ # If that fails, try with strict=False
45
+ self.model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")), strict=False)
46
+ print("Model loaded with strict=False")
47
+
48
+ def forward(self, pixel_values, labels=None, **kwargs):
49
+ logits = self.model(pixel_values)
50
+
51
+ loss = None
52
+ if labels is not None:
53
+ loss_fct = nn.CrossEntropyLoss()
54
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
55
+
56
+ return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}
57
+
58
+ # Function to create and load the model
59
+ def get_model():
60
+ config = AutoConfig.from_pretrained("./")
61
+ model = AIDetectorForImageClassification(config)
62
+ return model
preprocessor_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_normalize": true,
3
+ "do_resize": true,
4
+ "image_mean": [0.485, 0.456, 0.406],
5
+ "image_std": [0.229, 0.224, 0.225],
6
+ "size": {
7
+ "height": 224,
8
+ "width": 224
9
+ }
10
+ }