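# Gradio demo: classify a flower photo as daisy, dandelion, roses, sunflowers,
# or tulips with a ViT model fine-tuned for flower classification.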
import torch
import gradio as gr
from transformers import AutoModelForImageClassification, AutoFeatureExtractor

model_id = "chanelcolgate/vit-base-patch16-224-finetuned-flower"
labels = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]


# Load the fine-tuned model and its feature extractor once at startup,
# instead of reloading them on every request.
model = AutoModelForImageClassification.from_pretrained(model_id)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)


def classify_image(image):
    # Preprocess the input image into the tensor format the model expects.
    inp = feature_extractor(image, return_tensors="pt")
    # Run inference without tracking gradients.
    with torch.no_grad():
        outp = model(**inp)
    # Convert logits to class probabilities.
    preds = torch.nn.functional.softmax(outp.logits, dim=-1)[0].numpy()
    # Map each label to its confidence for the Label output component.
    return {label: float(preds[i]) for i, label in enumerate(labels)}


# Build the Gradio interface, then launch it.
interface = gr.Interface(
    fn=classify_image,
    inputs="image",
    outputs="label",
    examples=["flower-1.jpeg", "flower-2.jpeg"],
)
interface.launch()