MNCJihun commited on
Commit
2f77bb4
·
1 Parent(s): ec3b03d

add model path

Browse files
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -9,9 +9,10 @@ import os
9
  import gradio as gr
10
  import albumentations as A
11
  from albumentations.pytorch import ToTensorV2
12
- import requests
13
- import random
14
 
 
 
15
 
16
  test_transforms = A.Compose(
17
  [
@@ -33,13 +34,14 @@ def predict(img):
33
 
34
  labels = ['no_trunk', 'trunk']
35
 
36
- model = models.resnet50(pretrained=True)
37
  model.fc = nn.Sequential(
38
  nn.Dropout(0.5),
39
  nn.Linear(model.fc.in_features, 2)
40
  )
41
- MODEL_PATH=''
42
- # model.load_state_dict(torch.load(MODEL_PATH))
 
43
  model.eval()
44
 
45
  inputs = gr.inputs.Image()
 
9
  import gradio as gr
10
  import albumentations as A
11
  from albumentations.pytorch import ToTensorV2
12
+ import urllib.request
 
13
 
14
+ MODEL_URL = "https://huggingface.co/caisarl76/HI_motorcycle_trunk_cls_model/resolve/main/best_model.pth"
15
+ MODEL_PATH = "/workspace/result/best_model.pth"
16
 
17
  test_transforms = A.Compose(
18
  [
 
34
 
35
  labels = ['no_trunk', 'trunk']
36
 
37
+ model = models.resnet50(pretrained=False)
38
  model.fc = nn.Sequential(
39
  nn.Dropout(0.5),
40
  nn.Linear(model.fc.in_features, 2)
41
  )
42
+
43
+ urllib.request.urlretrieve(MODEL_URL, MODEL_PATH)
44
+ model.load_state_dict(torch.load(MODEL_PATH))
45
  model.eval()
46
 
47
  inputs = gr.inputs.Image()