# Spaces: Sleeping
import gradio as gr | |
import os | |
import torch.nn.functional as F | |
import torch | |
from torchvision import transforms | |
# Load the serialized model object onto the CPU; it is called directly in
# predict_image, so this is a full pickled module rather than a state_dict.
# NOTE(review): torch.load unpickles arbitrary Python objects — only ever point
# it at a trusted checkpoint. On torch >= 2.6 the default weights_only=True
# rejects full-module pickles; if loading fails there, pass weights_only=False
# (confirm against the pinned torch version).
model = torch.load("./model.pth", map_location=torch.device("cpu"))
# This app only runs inference: switch dropout / batch-norm layers to their
# evaluation behaviour so the same image always yields the same prediction.
model.eval()
# Side length in pixels that every input image is resized to (square).
IMG_SIZE = 224

# Human-readable class names. predict_image pairs these positionally with the
# model's softmax outputs, so the order must match the model's class indices.
MASK_LABEL = ["Mask worn properly.", "Mask not worn properly: nose out", "Mask not worn properly: chin and nose out", "Didn't wear mask."]

# Standard ImageNet per-channel (R, G, B) mean and standard deviation.
# NOTE(review): presumably the statistics the model was trained with — confirm
# against the training notebook.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Inference-time preprocessing: resize to IMG_SIZE x IMG_SIZE, convert to a
# float tensor, then normalize each channel.
transforms_test = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)
def predict_image(image):
    """Classify how a face mask is worn in a single image.

    Args:
        image: A PIL image (the Gradio input is configured with ``type="pil"``),
            or ``None`` when no frame has been captured yet.

    Returns:
        A dict mapping each label in ``MASK_LABEL`` to its softmax probability
        (a Python float), ordered from most to least likely. Returns ``None``
        when ``image`` is ``None``, so the output component is simply left
        empty instead of the request erroring out.
    """
    if image is None:
        return None

    # Normalize in transforms_test expects exactly 3 channels; uploaded or
    # captured frames may be RGBA or greyscale, so coerce to RGB first.
    batch = torch.unsqueeze(transforms_test(image.convert("RGB")), 0)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        logits = model(batch)

    # Batch of one, so row 0 holds this image's class probabilities.
    probabilities = F.softmax(logits, dim=1)[0].tolist()
    scores = dict(zip(MASK_LABEL, probabilities))
    return dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))
# Text rendered above (title, description) and below (article) the demo.
title = "ViT Mask Detection"
description = "<p style='text-align: center'>Gradio demo for ViT-16 Mask Image Classification created by <a href='https://github.com/stevenlimcorn'>Steven Limcorn</a></p>"
article = "<p style='text-align: center'>An Application made by stevenlimcorn. Notebook access at: <a href='https://github.com/stevenlimcorn/Mask-Classification'>Mask Classification</a></p>"

# A webcam snapshot delivered to predict_image as a PIL image ...
webcam_input = gr.Image(label="Input Image", type="pil", source="webcam")
# ... and the returned {label: probability} dict shown as ranked confidences.
label_output = gr.Label()

demo = gr.Interface(
    fn=predict_image,
    inputs=webcam_input,
    outputs=label_output,
    title=title,
    description=description,
    article=article,
)
demo.launch()