yaya36095 committed
Commit 2ceb92c · verified · 1 Parent(s): b25f77f

Update handler.py

Files changed (1):
  1. handler.py  +43 -124
handler.py CHANGED
@@ -1,131 +1,50 @@
-import os
-import torch
-import torch.nn as nn
-import torchvision.models as models
-import torchvision.transforms as transforms
+from transformers import AutoFeatureExtractor, AutoModelForImageClassification
 from PIL import Image
-import json
-import base64
-from io import BytesIO
-from typing import Dict, List, Any, Union
-
-# Define the model architecture based on EfficientNetV2-S
-class AIDetectorModel(nn.Module):
-    def __init__(self):
-        super(AIDetectorModel, self).__init__()
-        # Load EfficientNetV2-S as base model
-        self.base_model = models.efficientnet_v2_s(weights=None)
-
-        # Replace classifier with custom layers
-        self.base_model.classifier = nn.Sequential(
-            nn.Linear(self.base_model.classifier[1].in_features, 1024),
-            nn.ReLU(),
-            nn.Dropout(p=0.3),
-            nn.Linear(1024, 512),
-            nn.ReLU(),
-            nn.Dropout(p=0.3),
-            nn.Linear(512, 2) # 2 classes: real or AI-generated
-        )
-
-    def forward(self, x):
-        return self.base_model(x)
+import torch
 
-# This is the handler that will be used by the Hugging Face Inference API
-class Handler:
-    def __init__(self):
-        self.initialized = False
-        self.model = None
-        self.device = None
-        self.transform = None
-
-    def initialize(self, context=None):
-        """Initialize the handler with model and preprocessing"""
+class EndpointHandler:
+    def __init__(self, model_dir):
+        self.model = AutoModelForImageClassification.from_pretrained(model_dir)
+        self.processor = AutoFeatureExtractor.from_pretrained(model_dir)
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        print(f"Using device: {self.device}")
-
-        # Initialize the model
-        self.model = AIDetectorModel()
+        self.model.to(self.device)
 
-        # Load the trained weights
-        model_path = os.path.join(os.getcwd(), "best_model_improved.pth")
+    def __call__(self, data):
+        """
+        Args:
+            data: Image data in binary format
+        Returns:
+            Prediction result as a dictionary
+        """
         try:
-            # Try to load with strict=True first
-            self.model.load_state_dict(torch.load(model_path, map_location=self.device))
-            print(f"Model loaded successfully from {model_path}")
-        except Exception as e:
-            print(f"Error with strict loading: {e}")
-            print("Trying with strict=False...")
-            # If that fails, try with strict=False
-            self.model.load_state_dict(torch.load(model_path, map_location=self.device), strict=False)
-            print("Model loaded with strict=False")
+            # Load and process image
+            image = Image.open(data).convert("RGB")
+            inputs = self.processor(images=image, return_tensors="pt")
+            inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
-        self.model.to(self.device)
-        self.model.eval() # Set to evaluation mode
-
-        # Define image transformations - same as used in training
-        self.transform = transforms.Compose([
-            transforms.Resize((224, 224)),
-            transforms.ToTensor(),
-            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-        ])
-
-        self.initialized = True
-        print("Model initialization complete")
-        return self
-
-    def preprocess(self, data: Union[Dict, List, str, bytes]) -> torch.Tensor:
-        """Process input data for model inference"""
-        images = []
-
-        # Handle different input formats
-        if isinstance(data, dict):
-            # API format where data is a dictionary
-            if "inputs" in data:
-                data = data["inputs"]
-            elif "image" in data:
-                data = data["image"]
-
-        # Convert to list for batch processing
-        if not isinstance(data, list):
-            data = [data]
-
-        for item in data:
-            # Process each item based on type
-            if isinstance(item, str):
-                if os.path.isfile(item):
-                    # Local file path
-                    image = Image.open(item).convert('RGB')
-                elif item.startswith("http"):
-                    # URL
-                    from urllib.request import urlopen
-                    image = Image.open(BytesIO(urlopen(item).read())).convert('RGB')
-                elif item.startswith("data:image"):
-                    # Base64 image with header
-                    image_data = item.split(",")[1]
-                    image = Image.open(BytesIO(base64.b64decode(image_data))).convert('RGB')
-                else:
-                    # Assume base64 encoded image
-                    try:
-                        image = Image.open(BytesIO(base64.b64decode(item))).convert('RGB')
-                    except Exception:
-                        # If not base64, try as file path again
-                        image = Image.open(item).convert('RGB')
-            elif isinstance(item, bytes):
-                # Raw bytes
-                image = Image.open(BytesIO(item)).convert('RGB')
-            elif isinstance(item, Image.Image):
-                # Already a PIL Image
-                image = item.convert('RGB')
-            else:
-                raise ValueError(f"Unsupported input type: {type(item)}")
+            # Get prediction
+            with torch.no_grad():
+                outputs = self.model(**inputs)
+                logits = outputs.logits
+                probabilities = torch.nn.functional.softmax(logits, dim=-1)
+
+            # Get prediction class and confidence
+            predicted_class_idx = probabilities.argmax().item()
+            confidence = probabilities[0][predicted_class_idx].item()
 
-            # Apply transformations
-            image_tensor = self.transform(image)
-            images.append(image_tensor)
-
-        # Stack tensors for batch processing
-        batch = torch.stack(images).to(self.device)
-        return batch
-
-    def inference(self, input_batch: torch.Tensor) -> List[Dict[str, Any]]:
-        """Run model inference on the input batch"""
+            # Get class labels
+            id2label = self.model.config.id2label
+            predicted_class = id2label[predicted_class_idx]
+
+            # Return results
+            return {
+                "predicted_class": predicted_class,
+                "confidence": confidence,
+                "all_probabilities": {
+                    id2label[i]: prob.item()
+                    for i, prob in enumerate(probabilities[0])
+                }
+            }
+
+        except Exception as e:
+            return {"error": str(e)}