yaya36095 commited on
Commit
1999a06
·
verified ·
1 Parent(s): 80f8076

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +74 -29
handler.py CHANGED
@@ -1,50 +1,95 @@
1
- from transformers import AutoFeatureExtractor, AutoModelForImageClassification
2
  from PIL import Image
3
  import torch
 
 
4
 
5
  class EndpointHandler:
6
  def __init__(self, model_dir):
7
- self.model = AutoModelForImageClassification.from_pretrained(model_dir)
8
- self.processor = AutoFeatureExtractor.from_pretrained(model_dir)
9
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
- self.model.to(self.device)
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  def __call__(self, data):
13
  """
14
- Args:
15
- data: Image data in binary format
16
- Returns:
17
- Prediction result as a dictionary
 
 
 
18
  """
19
  try:
20
- # Load and process image
21
- image = Image.open(data).convert("RGB")
22
- inputs = self.processor(images=image, return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
24
 
25
- # Get prediction
26
  with torch.no_grad():
27
  outputs = self.model(**inputs)
28
  logits = outputs.logits
29
  probabilities = torch.nn.functional.softmax(logits, dim=-1)
30
 
31
- # Get prediction class and confidence
32
- predicted_class_idx = probabilities.argmax().item()
33
- confidence = probabilities[0][predicted_class_idx].item()
34
-
35
- # Get class labels
36
- id2label = self.model.config.id2label
37
- predicted_class = id2label[predicted_class_idx]
38
-
39
- # Return results
40
- return {
41
- "predicted_class": predicted_class,
42
- "confidence": confidence,
43
- "all_probabilities": {
44
- id2label[i]: prob.item()
45
- for i, prob in enumerate(probabilities[0])
 
46
  }
47
  }
48
 
 
 
49
  except Exception as e:
50
- return {"error": str(e)}
 
 
 
 
1
+ from transformers import ViTFeatureExtractor, ViTForImageClassification
2
  from PIL import Image
3
  import torch
4
+ import base64
5
+ import io
6
 
7
  class EndpointHandler:
8
  def __init__(self, model_dir):
9
+ """
10
+ تهيئة النموذج ومعالج الميزات
 
 
11
 
12
+ المعلمات:
13
+ model_dir: مسار مجلد النموذج
14
+ """
15
+ try:
16
+ # تحميل النموذج ومعالج الميزات
17
+ self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_dir)
18
+ self.model = ViTForImageClassification.from_pretrained(model_dir)
19
+
20
+ # نقل النموذج إلى وحدة المعالجة المناسبة
21
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
+ self.model.to(self.device)
23
+
24
+ # وضع النموذج في وضع التقييم
25
+ self.model.eval()
26
+
27
+ print(f"تم تحميل النموذج بنجاح على جهاز {self.device}")
28
+ except Exception as e:
29
+ print(f"خطأ في تحميل النموذج: {str(e)}")
30
+ raise
31
+
32
  def __call__(self, data):
33
  """
34
+ معالجة البيانات المدخلة وإجراء التنبؤ
35
+
36
+ المعلمات:
37
+ data: بيانات الصورة المشفرة بـ base64 أو كائن الصورة
38
+
39
+ العائد:
40
+ dict: نتائج التنبؤ
41
  """
42
  try:
43
+ # التحقق من نوع البيانات المدخلة
44
+ if isinstance(data, dict) and "image" in data:
45
+ # إذا كانت البيانات مشفرة بـ base64
46
+ image_data = data["image"]
47
+ if isinstance(image_data, str) and image_data.startswith("data:image"):
48
+ # إزالة بادئة URL للبيانات
49
+ image_data = image_data.split(",")[1]
50
+
51
+ # فك تشفير البيانات
52
+ image_bytes = base64.b64decode(image_data)
53
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
54
+ elif isinstance(data, bytes):
55
+ # إذا كانت البيانات ثنائية
56
+ image = Image.open(io.BytesIO(data)).convert("RGB")
57
+ else:
58
+ return {"error": "تنسيق البيانات غير مدعوم"}
59
+
60
+ # معالجة الصورة
61
+ inputs = self.feature_extractor(images=image, return_tensors="pt")
62
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
63
 
64
+ # إجراء التنبؤ
65
  with torch.no_grad():
66
  outputs = self.model(**inputs)
67
  logits = outputs.logits
68
  probabilities = torch.nn.functional.softmax(logits, dim=-1)
69
 
70
+ # الحصول على التصنيف ونسبة الثقة
71
+ predicted_class_idx = probabilities.argmax().item()
72
+ confidence = probabilities[0][predicted_class_idx].item()
73
+
74
+ # تحويل الفهرس إلى تسمية
75
+ id2label = self.model.config.id2label
76
+ predicted_class = id2label[predicted_class_idx]
77
+
78
+ # إعداد النتائج
79
+ results = {
80
+ "prediction": predicted_class,
81
+ "confidence": float(confidence),
82
+ "is_fake": predicted_class == "fake",
83
+ "probabilities": {
84
+ label: float(prob)
85
+ for label, prob in zip(id2label.values(), probabilities[0].cpu().numpy())
86
  }
87
  }
88
 
89
+ return results
90
+
91
  except Exception as e:
92
+ # معالجة الأخطاء
93
+ error_message = str(e)
94
+ print(f"خطأ في معالجة الصورة: {error_message}")
95
+ return {"error": error_message}