yaya36095 commited on
Commit
e375561
·
verified ·
1 Parent(s): 40dd3b2

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +25 -37
handler.py CHANGED
@@ -1,37 +1,25 @@
1
- import torch
2
- from torchvision import transforms
3
- from PIL import Image
4
- from torchvision.models import efficientnet_v2_s
5
-
6
- class EndpointHandler:
7
- def __init__(self, path=""):
8
- self.model = efficientnet_v2_s(weights=None)
9
- self.model.classifier = torch.nn.Sequential(
10
- torch.nn.Linear(self.model.classifier[1].in_features, 1024),
11
- torch.nn.ReLU(),
12
- torch.nn.Dropout(0.3),
13
- torch.nn.Linear(1024, 512),
14
- torch.nn.ReLU(),
15
- torch.nn.Dropout(0.3),
16
- torch.nn.Linear(512, 2)
17
- )
18
- self.model.load_state_dict(torch.load(f"{path}/pytorch_model.bin", map_location=torch.device("cpu")))
19
- self.model.eval()
20
-
21
- self.transform = transforms.Compose([
22
- transforms.Resize((224, 224)),
23
- transforms.ToTensor(),
24
- transforms.Normalize([0.485, 0.456, 0.406],
25
- [0.229, 0.224, 0.225])
26
- ])
27
-
28
- def __call__(self, data):
29
- image = Image.open(data["inputs"]).convert("RGB")
30
- image = self.transform(image).unsqueeze(0)
31
- with torch.no_grad():
32
- outputs = self.model(image)
33
- probs = torch.nn.functional.softmax(outputs[0], dim=0)
34
- return {
35
- "real": round(probs[0].item(), 4),
36
- "fake": round(probs[1].item(), 4)
37
- }
 
1
+ from typing import Dict, Any
2
+ from PIL import Image
3
+ from io import BytesIO
4
+ import base64
5
+
6
+ from inference import load_model, predict
7
+
8
+ # تحميل النموذج عند بدء السيرفر
9
+ model = load_model()
10
+
11
+ # الدالة اللي بيتم استدعاؤها عند رفع صورة
12
+ def predict_image(inputs: Dict[str, Any]) -> Dict[str, float]:
13
+ # لو كانت الصورة مرسلة كـ base64
14
+ if "image" in inputs:
15
+ image_data = inputs["image"]
16
+ if isinstance(image_data, str):
17
+ image = Image.open(BytesIO(base64.b64decode(image_data))).convert("RGB")
18
+ else:
19
+ image = Image.open(image_data).convert("RGB")
20
+ else:
21
+ raise ValueError("Missing 'image' key in input")
22
+
23
+ # التنبؤ
24
+ result = predict(model, image)
25
+ return result