# Last commit by KhadijaAsehnoune12: 82e3bb8 "Update app.py" (verified)
import functools

import gradio as gr
import numpy as np
import rembg
import torch
from PIL import Image
from transformers import ViTFeatureExtractor, ViTForImageClassification
# Hugging Face Hub checkpoint holding the fine-tuned ViT leaf-disease classifier.
model_name = "KhadijaAsehnoune12/ViTOrangeLeafDiseaseClassifier"

# Preprocessor that turns a PIL image into the pixel tensors the model expects.
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

# Classification model with a 10-way head (one logit per entry in id2label).
# NOTE(review): with ignore_mismatched_sizes=True, a checkpoint head whose size
# differs from num_labels is re-initialised instead of raising an error, which
# would make predictions meaningless — confirm the checkpoint has 10 labels.
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=10,
    ignore_mismatched_sizes=True,
)
# Define the label mapping
id2label = {
"0": "Aleurocanthus spiniferus",
"1": "Chancre citrique",
"2": "Cochenille blanche",
"3": "Dépérissement des agrumes",
"4": "Feuille saine",
"5": "Jaunissement des feuilles",
"6": "Maladie de l'oïdium",
"7": "Maladie du dragon jaune",
"8": "Mineuse des agrumes",
"9": "Trou de balle"
}
@functools.lru_cache(maxsize=1)
def _get_rembg_session():
    """Return a single shared rembg "u2net" session, created on first use."""
    # rembg.remove() builds a new session (loading the ONNX model) whenever no
    # session is passed; the rembg README recommends creating one and reusing
    # it when processing many images, as this app does once per request.
    return rembg.new_session("u2net")


def remove_background(image):
    """Remove the background from *image* and flatten the result onto white.

    Args:
        image: PIL image; it is converted to RGBA before segmentation.

    Returns:
        An RGB PIL image with the same size as *image*. Pixels that rembg makes
        transparent (background) become white.
    """
    rgba = image.convert("RGBA")
    cutout = Image.fromarray(
        rembg.remove(np.array(rgba), session=_get_rembg_session())
    )
    # alpha_composite needs two RGBA images of equal size; white_bg matches.
    white_bg = Image.new("RGBA", rgba.size, "WHITE")
    return Image.alpha_composite(white_bg, cutout).convert("RGB")
def predict(image):
    """Classify a citrus-leaf image and report the top class with its confidence.

    The background is removed first, then the image is preprocessed with
    ``feature_extractor`` and scored by ``model``.

    Args:
        image: PIL image from the Gradio input, or ``None`` when the user
            submits without uploading an image.

    Returns:
        ``"<label>: <confidence>%"`` for the highest-probability class. The
        confidence is its softmax probability as a percentage, with 2 decimals.

    Raises:
        gr.Error: if no image was provided. Gradio shows this message instead
            of the AttributeError that ``None.convert`` would otherwise raise.
    """
    if image is None:
        raise gr.Error("Veuillez télécharger une image de feuille.")
    image = remove_background(image)
    inputs = feature_extractor(images=image, return_tensors="pt")
    # Inference only: skip autograd graph construction to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Batch size is 1, so take the probability vector of the single image.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    predicted_class_idx = probs.argmax().item()
    predicted_label = id2label[str(predicted_class_idx)]
    confidence_score = probs[predicted_class_idx].item() * 100
    return f"{predicted_label}: {confidence_score:.2f}%"
# Gradio UI: one PIL image in, the formatted prediction string out.
image = gr.Image(type="pil")
label = gr.Textbox(label="Prediction")

# Build the class list from id2label so the description always matches the
# labels predict() can actually return.
description = (
    "Téléchargez une image d'une feuille d'agrume pour classer sa maladie. "
    "Le modèle est entraîné sur les maladies suivantes : "
    + ", ".join(id2label.values())
    + "."
)

demo = gr.Interface(
    fn=predict,
    inputs=image,
    outputs=label,
    title="Classification des maladies des agrumes",
    description=description,
    examples=["maladie_du_dragon_jaune.jpg", "feuille_saine.jpg"],
)

# Launch only when run as a script, so importing this module (e.g. to test
# predict) does not start a server.
if __name__ == "__main__":
    demo.launch(share=True)