"""AI-generated image detector: EfficientNetV2-S backbone wrapped as a Hugging Face model."""

import logging
import os

import torch
import torch.nn as nn
import torchvision.models as models
from transformers import AutoConfig, PreTrainedModel

logger = logging.getLogger(__name__)

# File name of the trained checkpoint; resolved against the current working
# directory when no explicit path is supplied.
DEFAULT_CHECKPOINT = "best_model_improved.pth"


# Define the model architecture based on EfficientNetV2-S
class AIDetectorModel(nn.Module):
    """EfficientNetV2-S backbone whose classifier is replaced by a 3-layer MLP head.

    The head maps the backbone features to 2 logits (real vs. AI-generated).
    """

    def __init__(self):
        super().__init__()
        # Randomly initialised backbone; trained weights are loaded separately.
        self.base_model = models.efficientnet_v2_s(weights=None)
        # Reuse the input width of the Linear layer at index 1 of the stock classifier.
        in_features = self.base_model.classifier[1].in_features
        # Replace classifier with custom layers
        self.base_model.classifier = nn.Sequential(
            nn.Linear(in_features, 1024),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(512, 2),  # 2 classes: real or AI-generated
        )

    def forward(self, x):
        """Return the raw class logits for a batch of images ``x``.

        NOTE(review): ``x`` is presumably a (batch, 3, H, W) float tensor
        prepared by the caller's preprocessing; confirm against the pipeline.
        """
        return self.base_model(x)


# Wrapper class to make the model compatible with Hugging Face
class AIDetectorForImageClassification(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around :class:`AIDetectorModel`.

    The trained weights are loaded eagerly, onto the CPU, during construction.

    Args:
        config: A transformers config object; only ``config.num_labels`` is read here.
        model_path: Path to the ``state_dict`` checkpoint. Defaults to
            ``DEFAULT_CHECKPOINT`` in the current working directory.

    Raises:
        FileNotFoundError / pickle errors: propagated from ``torch.load`` when
            the checkpoint cannot be read.
        RuntimeError: if even non-strict loading fails (e.g. tensor size mismatch).
    """

    def __init__(self, config, model_path=None):
        super().__init__(config)
        # NOTE(review): the head always emits 2 logits, while the loss in
        # ``forward`` reshapes with ``num_labels``; config.num_labels must be 2.
        self.num_labels = config.num_labels
        self.model = AIDetectorModel()

        if model_path is None:
            model_path = os.path.join(os.getcwd(), DEFAULT_CHECKPOINT)
        self._load_detector_weights(model_path)

    def _load_detector_weights(self, model_path):
        """Load ``model_path`` into ``self.model``, falling back to non-strict loading."""
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        # Load once so the strict=False fallback does not unpickle the file again.
        state_dict = torch.load(model_path, map_location=torch.device("cpu"))
        try:
            self.model.load_state_dict(state_dict)
        except RuntimeError as err:
            # Key/size mismatch under strict loading; retry, tolerating key mismatches.
            logger.warning("Error with strict loading of %s: %s", model_path, err)
            logger.warning("Trying with strict=False...")
            result = self.model.load_state_dict(state_dict, strict=False)
            # Report what was skipped so a partially-initialised model is not silent.
            logger.warning(
                "Model loaded with strict=False (missing keys: %s; unexpected keys: %s)",
                result.missing_keys,
                result.unexpected_keys,
            )
        else:
            logger.info("Model loaded successfully from %s", model_path)

    def forward(self, pixel_values, labels=None, **kwargs):
        """Run the detector and optionally compute the cross-entropy loss.

        Args:
            pixel_values: Image batch passed straight to :class:`AIDetectorModel`.
            labels: Optional integer class targets; flattened before the loss.
            **kwargs: Accepted for Hugging Face API compatibility and ignored.

        Returns:
            ``{"logits": logits}``, or ``{"loss": loss, "logits": logits}`` when
            ``labels`` is given.
        """
        logits = self.model(pixel_values)
        if labels is None:
            return {"logits": logits}

        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return {"loss": loss, "logits": logits}


# Function to create and load the model
def get_model(config_path="./", model_path=None):
    """Build an :class:`AIDetectorForImageClassification` with trained weights.

    Args:
        config_path: Directory or hub id passed to ``AutoConfig.from_pretrained``.
        model_path: Optional checkpoint path; see the wrapper's ``__init__``.
    """
    config = AutoConfig.from_pretrained(config_path)
    return AIDetectorForImageClassification(config, model_path=model_path)