Spaces:
Sleeping
Sleeping
import torch | |
import torchvision | |
import gradio as gr | |
from model import create_vit_b16 | |
# loading the classnames | |
class_names = [] | |
with open("/content/Birds_Classification/class_names.txt", "r") as f: | |
for cls in f.readlines(): | |
class_names.append(cls[:-1]) | |
# creating a ViT model | |
vit_b16_model, vit_transforms = create_vit_b16(num_classes=len(class_names)) | |
vit_b16_model.load_state_dict( | |
torch.load( | |
f="/content/Birds_Classification/ViT_B16_510_classes.pth", | |
map_location=torch.device("cpu") | |
) | |
) | |
# creating a predict function | |
def predict(img): | |
""" | |
Makes prediction on the given image | |
""" | |
img = vit_transforms(img).unsqueeze(dim=0) | |
vit_b16_model.eval() | |
with torch.inference_mode(): | |
pred_logits = vit_b16_model(img) | |
preds = torch.softmax(pred_logits, dim=1) | |
# Create a prediction label and prediction probability dictionary for each prediction class | |
pred_and_prob_labels = {class_names[i]: preds[i]*100 for i in range(len(class_names))} | |
return pred_and_prob_labels | |
# creating title, description for the webpage | |
title = "Birds Classifier 🪶" | |
description = "Classifies an Image of a Bird to any one of the 510 species." | |
# creating examples list | |
examples_list = [["examples/" + img] for img in os.listdir("examples")] | |
# Building a gradio app | |
bird_classification = gr.Interface( | |
fn=predict, | |
inputs=gr.Image(type="pil"), | |
outputs=gr.Label(num_top_classes=3, label="Prediction"), | |
examples = examples_list, | |
title=title, | |
description=description | |
) | |
# launching the web app | |
bird_classification.launch() | |