# Hugging Face Space app (page status when captured: "Runtime error"; file size 1,868 bytes).
import os
import gradio as gr
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import urllib.request
# Inference-time preprocessing (albumentations): rescale so the shorter image
# side is 350 px, center-crop to 256x256, normalize each channel with the
# standard ImageNet mean/std, and convert the HWC array to a CHW torch tensor.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)
test_transforms = A.Compose([
    A.SmallestMaxSize(max_size=350),
    A.CenterCrop(height=256, width=256),
    A.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ToTensorV2(),
])
# Example images shown beneath the demo, taken from ./sample relative to the
# working directory. Entries are sorted because os.listdir order is arbitrary.
# Only regular, non-hidden files are kept so that sub-directories or dotfiles
# are never offered to Gradio as examples.
_SAMPLE_DIR = os.path.join(os.getcwd(), 'sample')
img_samples = [
    os.path.join(_SAMPLE_DIR, name)
    for name in sorted(os.listdir(_SAMPLE_DIR))
    if not name.startswith('.')
    and os.path.isfile(os.path.join(_SAMPLE_DIR, name))
]
# Classifier checkpoint hosted on the Hugging Face Hub and its local cache path.
MODEL_URL = "https://huggingface.co/caisarl76/HI_motorcycle_trunk_cls_model/resolve/main/best_model.pth"
MODEL_PATH = "/tmp/best_model.pth"
# Download only when no cached copy exists. The file is written under a
# temporary name and then atomically renamed. An interrupted download
# therefore never leaves a truncated file at MODEL_PATH that a later run
# would mistake for a complete checkpoint.
if not os.path.isfile(MODEL_PATH):
    _partial_path = MODEL_PATH + ".part"
    urllib.request.urlretrieve(MODEL_URL, _partial_path)
    os.replace(_partial_path, MODEL_PATH)
def predict(img):
    """Classify an image and return per-label probabilities.

    Args:
        img: Image array from the Gradio image input; it is cast to uint8
            before preprocessing. Presumably H x W x 3 RGB (TODO confirm the
            input component's image_mode).

    Returns:
        Dict mapping each name in ``labels`` to its softmax probability, or
        ``None`` when no image was supplied (e.g. the input was cleared).
    """
    if img is None:
        return None
    # Run the module's ``test_transforms`` pipeline (resize, center-crop,
    # ImageNet normalization, HWC->CHW tensor). A bare ToTensor() only scales
    # to [0, 1] and would feed the network unnormalized inputs.
    batch = test_transforms(image=img.astype('uint8'))['image'].unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)[0]
    probs = torch.nn.functional.softmax(logits, dim=0)
    return {label: float(p) for label, p in zip(labels, probs)}
# Class names, in the same order as the classifier's output logits.
labels = ['no_trunk', 'trunk']
# ResNet-50 built without torchvision's pretrained weights. Its final layer is
# replaced by a dropout + linear head with one output per label. The trained
# weights are then loaded from MODEL_PATH.
model = models.resnet50(pretrained=False)
model.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(model.fc.in_features, len(labels))
)
device = torch.device('cpu')
# NOTE(review): torch.load unpickles the checkpoint, so only load files from a
# trusted source (here a fixed Hugging Face URL).
try:
    state_dict = torch.load(MODEL_PATH, map_location=device)
    model.load_state_dict(state_dict)
except Exception as err:
    # Best-effort: the app keeps running with randomly initialized weights.
    # It reports why loading failed rather than hiding it behind a bare
    # ``except:``, which would also swallow KeyboardInterrupt/SystemExit.
    print('CANNOT load model weight')
    print(f'  reason: {err!r}')
else:
    print('model load complete')
# Inference only: put dropout in eval mode and stop tracking gradients.
model.eval()
for param in model.parameters():
    param.requires_grad = False
# Gradio UI: one image upload in, the top-1 label out.
# The top-level ``gr.Image``/``gr.Label`` components are used because the
# ``gr.inputs``/``gr.outputs`` namespaces were deprecated in Gradio 3 and
# removed in Gradio 4. type="numpy" keeps ``predict`` receiving an array.
inputs = gr.Image(type="numpy")
outputs = gr.Label(num_top_classes=1)
gr.Interface(fn=predict,
             inputs=inputs,
             outputs=outputs,
             examples=img_samples).launch()
|