nemozajung commited on
Commit
e9a573f
·
verified ·
1 Parent(s): e5033a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -10
app.py CHANGED
@@ -3,27 +3,41 @@ from PIL import Image
3
  import gradio as gr
4
  from transformers import AutoImageProcessor, ResNetForImageClassification
5
 
6
- # โหลดโมเดลและ image processor
7
  model_name = "microsoft/resnet-50"
8
  model = ResNetForImageClassification.from_pretrained(model_name)
9
  processor = AutoImageProcessor.from_pretrained(model_name)
10
 
11
- # ฟังก์ชันรับภาพและคืน label ที่โมเดลทำนาย
12
- def classify_image(image):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  inputs = processor(images=image, return_tensors="pt")
14
  with torch.no_grad():
15
  logits = model(**inputs).logits
16
- predicted_class_idx = logits.argmax(-1).item()
17
- label = model.config.id2label[predicted_class_idx]
18
- return label
19
 
20
- # UI ด้วย Gradio
21
  demo = gr.Interface(
22
- fn=classify_image,
23
  inputs=gr.Image(type="pil"),
24
  outputs="text",
25
- title="ResNet-50 Image Classifier",
26
- description="จำแนกรูปภาพด้วย ResNet-50 ที่ฝึกด้วย ImageNet"
27
  )
28
 
29
  demo.launch()
 
3
  import gradio as gr
4
  from transformers import AutoImageProcessor, ResNetForImageClassification
5
 
6
+ # โหลดโมเดล ResNet-50
7
  model_name = "microsoft/resnet-50"
8
  model = ResNetForImageClassification.from_pretrained(model_name)
9
  processor = AutoImageProcessor.from_pretrained(model_name)
10
 
11
+ # จำลอง label BMI category
12
+ bmi_labels = ["Underweight", "Normal", "Overweight", "Obese"]
13
+
14
+ # ตัวอย่าง logic จำแนกเบื้องต้นโดย mapping จาก class id (mock-up logic)
15
+ def map_to_bmi(class_id):
16
+ if class_id < 250: # สุ่ม logic สำหรับ demo
17
+ return "Underweight"
18
+ elif class_id < 500:
19
+ return "Normal"
20
+ elif class_id < 750:
21
+ return "Overweight"
22
+ else:
23
+ return "Obese"
24
+
25
+ # ฟังก์ชันหลัก
26
+ def predict_bmi(image):
27
  inputs = processor(images=image, return_tensors="pt")
28
  with torch.no_grad():
29
  logits = model(**inputs).logits
30
+ class_id = logits.argmax(-1).item()
31
+ bmi_category = map_to_bmi(class_id)
32
+ return f"Estimated Body Type (BMI category): {bmi_category}"
33
 
34
+ # Gradio UI
35
  demo = gr.Interface(
36
+ fn=predict_bmi,
37
  inputs=gr.Image(type="pil"),
38
  outputs="text",
39
+ title="BMI Estimator (Demo with ResNet-50)",
40
+ description="จำแนกรูปร่างตามลักษณะ BMI ด้วย ResNet-50 (Demo Simulation)"
41
  )
42
 
43
  demo.launch()