File size: 1,868 Bytes
4fce033
e4dfec0
 
 
fca27e0
547f1df
 
fca27e0
 
2f77bb4
547f1df
 
 
 
 
 
 
 
 
7041714
a7c1b55
cf991f8
e4dfec0
2a50db8
98d3545
 
e4dfec0
 
 
 
 
fca27e0
 
 
 
 
 
547f1df
fca27e0
547f1df
2f77bb4
fca27e0
 
 
 
2f77bb4
582ad59
67e9c3a
 
 
 
 
 
 
fca27e0
512a750
 
fca27e0
 
 
 
 
e4dfec0
 
547f1df
 
fca27e0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
import gradio as gr
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import urllib.request

# Evaluation-time preprocessing pipeline (albumentations):
#   1. rescale so the shorter image side is 350 px,
#   2. take the central 256x256 crop,
#   3. normalise each channel with the standard ImageNet mean/std,
#   4. convert the result to a torch tensor (ToTensorV2).
test_transforms = A.Compose(
    transforms=[
        A.SmallestMaxSize(max_size=350),
        A.CenterCrop(height=256, width=256),
        A.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
        ToTensorV2(),
    ],
)

# Example images offered in the Gradio UI: every regular, non-hidden file in
# ./sample (relative to the working directory). The list is sorted because
# os.listdir() returns entries in arbitrary order. Hidden entries such as
# .gitkeep / .DS_Store and sub-directories are skipped, so they are not handed
# to Gradio as "images".
SAMPLE_DIR = os.path.join(os.getcwd(), 'sample')
img_samples = sorted(
    os.path.join(SAMPLE_DIR, name)
    for name in os.listdir(SAMPLE_DIR)
    if not name.startswith('.')
    and os.path.isfile(os.path.join(SAMPLE_DIR, name))
)

# Model weights, fetched from the Hugging Face Hub on every start-up;
# urlretrieve overwrites any file already present at MODEL_PATH.
MODEL_URL = "https://huggingface.co/caisarl76/HI_motorcycle_trunk_cls_model/resolve/main/best_model.pth"
MODEL_PATH = "/tmp/best_model.pth"
urllib.request.urlretrieve(MODEL_URL, MODEL_PATH)


def predict(img):
    """Classify a single image with the trunk / no-trunk model.

    Args:
        img: image array as delivered by the Gradio ``Image`` input
            (presumably an HxWxC uint8 RGB array -- confirm for the pinned
            Gradio version).

    Returns:
        dict mapping each entry of ``labels`` to its softmax probability,
        the format consumed by the Gradio ``Label`` output.
    """
    import numpy as np  # local import: numpy is not imported at file top

    # Coerce to an 8-bit RGB image. Letting PIL infer the mode and then
    # converting also handles grayscale (H, W) or RGBA (H, W, 4) arrays,
    # which a forced mode='RGB' would misinterpret.
    pil_img = Image.fromarray(np.asarray(img).astype('uint8')).convert('RGB')

    # Run the image through the same resize / crop / ImageNet-normalisation
    # pipeline defined in test_transforms. Previously a bare ToTensor() was
    # used here, so the network saw un-normalised, un-resized input.
    batch = test_transforms(image=np.array(pil_img))['image'].unsqueeze(0)

    with torch.no_grad():
        probs = torch.nn.functional.softmax(model(batch)[0], dim=0)

    # One entry per class name instead of a hard-coded range(2).
    return {label: float(p) for label, p in zip(labels, probs)}

# Class names, indexed by the classifier's output unit.
labels = ['no_trunk', 'trunk']

# ResNet-50 backbone built without pretrained weights (the real weights come
# from MODEL_PATH below). Its final layer is replaced by dropout plus a linear
# head with one output per label.
model = models.resnet50(pretrained=False)
model.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(model.fc.in_features, len(labels)),
)

device = torch.device('cpu')

# NOTE(review): torch.load unpickles the file and can therefore execute
# arbitrary code. That is only acceptable because MODEL_PATH is downloaded from
# the project's own model repo. Consider torch.load(..., weights_only=True) if
# the installed torch version supports it.
try:
    state_dict = torch.load(MODEL_PATH, map_location=device)
    model.load_state_dict(state_dict)
except Exception as err:
    # Still best-effort (the app keeps running), but no longer a bare
    # `except:`. That form also swallowed KeyboardInterrupt/SystemExit and hid
    # the reason, even though the model is then serving random weights.
    print(f'CANNOT load model weight: {err!r}')
else:
    print('model load complete')

# Inference only: put dropout / batch-norm into eval mode and stop gradient
# tracking on every parameter.
model.eval()
for param in model.parameters():
    param.requires_grad = False

# NOTE(review): gr.inputs / gr.outputs are the legacy Gradio component
# namespaces (removed in Gradio 4). Switch to gr.Image() / gr.Label() once the
# pinned Gradio version allows it.
inputs = gr.inputs.Image()
outputs = gr.outputs.Label(num_top_classes=1)
gr.Interface(fn=predict,
             inputs=inputs,
             outputs=outputs,
             examples=img_samples).launch()