william1324 commited on
Commit
c8b0511
·
verified ·
1 Parent(s): a86822b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -46
app.py CHANGED
@@ -1,52 +1,17 @@
1
- import torch
2
  import gradio as gr
3
  import numpy as np
4
- from PIL import Image
5
- import torchvision.transforms as T
6
 
7
- # 假設模型是分類用簡單 CNN(你要換成你的模型定義)
8
- class SimpleCNN(torch.nn.Module):
9
- def __init__(self):
10
- super().__init__()
11
- self.conv = torch.nn.Conv2d(3, 16, 3, stride=2)
12
- self.fc = torch.nn.Linear(16*111*111, 2) # 2分類為例
13
-
14
- def forward(self, x):
15
- x = self.conv(x)
16
- x = torch.relu(x)
17
- x = x.view(x.size(0), -1)
18
- x = self.fc(x)
19
- return torch.softmax(x, dim=1)
20
-
21
- # 載入模型與權重
22
- model = SimpleCNN()
23
- model.load_state_dict(torch.load("salg_model.pt", map_location="cpu"))
24
- model.eval()
25
-
26
- # 影像預處理
27
- transform = T.Compose([
28
- T.Resize((224, 224)),
29
- T.ToTensor(),
30
- ])
31
-
32
- def predict(image):
33
- img = Image.fromarray(image.astype('uint8'), 'RGB')
34
- input_tensor = transform(img).unsqueeze(0) # 加 batch 維度
35
- with torch.no_grad():
36
- pred = model(input_tensor)
37
- # 取最高機率標籤
38
- pred_label = torch.argmax(pred, dim=1).item()
39
- confidence = pred[0][pred_label].item()
40
- # 在圖上標示結果(簡單做法)
41
- result_img = image.copy()
42
- import cv2
43
- cv2.putText(result_img, f"Label: {pred_label} Conf: {confidence:.2f}",
44
- (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0,255,0), 2)
45
- return result_img, f"Label: {pred_label}, Confidence: {confidence:.2f}"
46
 
 
47
  with gr.Blocks() as demo:
48
- gr.Markdown("## 🦺 Helmet Detector with Model")
49
- gr.Markdown("請上傳圖片,點擊「辨識」由模型進行預測")
50
 
51
  with gr.Row():
52
  with gr.Column():
@@ -55,8 +20,8 @@ with gr.Blocks() as demo:
55
 
56
  with gr.Column():
57
  image_output = gr.Image(type="numpy", label="推論結果")
58
- result_label = gr.Textbox(label="辨識摘要")
59
 
60
- detect_button.click(fn=predict, inputs=image_input, outputs=[image_output, result_label])
61
 
62
  demo.launch()
 
 
1
  import gradio as gr
2
  import numpy as np
3
+ import random
 
4
 
5
+ # 模擬辨識按鈕的回傳邏輯
6
+ def fake_detect(image):
7
+ # 隨機回傳 "yap" 或 "not"
8
+ result = random.choice(["yap", "not"])
9
+ return image, result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ # 建立 Gradio 介面
12
  with gr.Blocks() as demo:
13
+ gr.Markdown("## 🦺 Helmet ")
14
+ gr.Markdown("請上傳一張圖片,然後點擊「辨識」按鈕模擬結果展示")
15
 
16
  with gr.Row():
17
  with gr.Column():
 
20
 
21
  with gr.Column():
22
  image_output = gr.Image(type="numpy", label="推論結果")
23
+ result_label = gr.Textbox(label="辨識摘要", placeholder="結果")
24
 
25
+ detect_button.click(fn=fake_detect, inputs=image_input, outputs=[image_output, result_label])
26
 
27
  demo.launch()