|
import torch |
|
import gradio as gr |
|
from PIL import Image |
|
import os |
|
from transformers import AutoImageProcessor, AutoModelForImageClassification |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score |
|
from sklearn.metrics import confusion_matrix |
|
from urllib.parse import quote |
|
import requests |
|
|
|
|
|
# Run inference on the GPU when CUDA is available; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
def load_model(model_name="linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification"):
    """Load the image processor and classification model used for inference.

    Args:
        model_name: Model identifier passed to both ``from_pretrained``
            calls. Defaults to the plant-disease checkpoint this script
            was written for, so existing ``load_model()`` calls behave as
            before.

    Returns:
        A ``(processor, model)`` tuple. The model has been moved to the
        module-level ``device`` and switched to evaluation mode.
    """
    processor = AutoImageProcessor.from_pretrained(model_name)
    model = AutoModelForImageClassification.from_pretrained(model_name).to(device)
    # eval() puts layers such as dropout / batch-norm into inference behaviour.
    model.eval()
    return processor, model
|
|
|
|
|
# Load the processor and model once at import time so every request reuses them.
processor, model = load_model()

# Class labels ordered by output index, taken from the checkpoint's own config.
# Deriving them here keeps the list in sync with the classifier head; a
# hand-maintained list can drift in length or order, which makes
# ``class_names[argmax]`` either wrong or out of range.
class_names = [model.config.id2label[idx] for idx in sorted(model.config.id2label)]
|
|
|
|
|
def predict(image):
    """Classify a leaf image and return the predicted label.

    Args:
        image: Either a PIL image or a filesystem path (``str``) to an
            image file. Paths are opened and converted to RGB; PIL images
            in another mode are converted to RGB as well.

    Returns:
        The label string of the highest-scoring class.
    """
    if isinstance(image, str):
        # The context manager closes the file handle; convert() returns a
        # separate, fully loaded image that stays valid afterwards.
        with Image.open(image) as opened:
            image = opened.convert("RGB")
    elif isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt").to(device)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_class_idx = logits.argmax(-1).item()

    # Prefer the label mapping shipped with the model so the name always
    # matches the classifier output; fall back to the module-level list when
    # the config does not provide one for this index.
    id2label = getattr(model.config, "id2label", None) or {}
    if predicted_class_idx in id2label:
        return id2label[predicted_class_idx]
    return class_names[predicted_class_idx]
|
|
|
|
|
def get_detailed_wikipedia_info(query):
    """Look up ``query`` on English Wikipedia and build a Markdown summary.

    Performs a search through the MediaWiki Action API, takes the top hit and
    fetches that page's plain-text intro extract.

    Args:
        query: Search text (the predicted disease label in this script).

    Returns:
        A ``(title, info)`` tuple. ``title`` is the top search result's
        title, or ``query`` unchanged when nothing was found or the lookup
        failed. ``info`` is a Markdown string (heading, extract, link to the
        article) or a short plain-text message describing why no article
        could be shown.
    """
    api_url = "https://en.wikipedia.org/w/api.php"
    # Without a timeout a stalled connection would block the UI callback forever.
    timeout_seconds = 10
    # NOTE(review): Wikimedia's User-Agent policy asks API clients to send a
    # descriptive User-Agent header; consider adding one for this app.

    try:
        # Let requests encode the query string instead of formatting it by hand.
        search_response = requests.get(
            api_url,
            params={
                "action": "query",
                "list": "search",
                "srsearch": query,
                "format": "json",
            },
            timeout=timeout_seconds,
        )
        search_response.raise_for_status()
        search_results = search_response.json().get("query", {}).get("search", [])

        if not search_results:
            return query, "No Wikipedia pages found."

        top_title = search_results[0]["title"]

        extract_response = requests.get(
            api_url,
            params={
                "action": "query",
                "prop": "extracts",
                "exintro": "",
                "explaintext": "",
                "titles": top_title,
                "format": "json",
            },
            timeout=timeout_seconds,
        )
        extract_response.raise_for_status()

        page_data = extract_response.json().get("query", {}).get("pages", {})
        # Default to an empty page so a missing "pages" entry yields the
        # fallback text rather than a StopIteration with an empty message.
        page = next(iter(page_data.values()), {})
        extract_text = page.get("extract") or "No detailed info found."

        page_link = f"https://en.wikipedia.org/wiki/{quote(top_title.replace(' ', '_'))}"
        # Emoji are written as escapes (herb, link symbol) so they survive
        # any re-encoding of this source file.
        info = (
            f"### \U0001F33F **{top_title}**\n\n"
            f"{extract_text.strip()}\n\n"
            f"\U0001F517 [Click here to read more on Wikipedia]({page_link})"
        )

        return top_title, info

    except Exception as e:
        # Deliberately broad: the lookup is best-effort and must never break
        # the prediction shown in the UI, so any failure becomes a message.
        return query, f"Error fetching info: {e}"
|
|
|
|
|
def wrapped_predict(image):
    """Gradio callback: classify the image and fetch background information.

    Args:
        image: The uploaded image as delivered by the input component, or
            ``None`` when the user submitted without choosing an image.

    Returns:
        A ``(disease_name, disease_info)`` tuple: the predicted label and
        the Wikipedia summary text for it. When no image was supplied, a
        short notice is returned in place of each value.
    """
    if image is None:
        # Submitting an empty form would otherwise crash inside the processor.
        return "No image provided", "Please upload a leaf image first."

    disease_name = predict(image)
    # Only the formatted info is shown; the resolved article title is unused.
    _, disease_info = get_detailed_wikipedia_info(disease_name)
    return disease_name, disease_info
|
|
|
|
|
def launch_web(share=True, debug=True):
    """Build the Gradio interface and start the web server.

    Args:
        share: Forwarded to ``Interface.launch``; the default ``True``
            matches the original behaviour of requesting a public share link.
        debug: Forwarded to ``Interface.launch``; defaults to ``True`` as
            before.
    """
    interface = gr.Interface(
        fn=wrapped_predict,
        inputs=gr.Image(type="pil", label="Upload Leaf Image"),
        outputs=[
            gr.Label(label="Predicted Disease"),
            # The info string is Markdown (heading, bold, link), so render it
            # with a Markdown component instead of showing raw markup.
            gr.Markdown(label="Disease Information"),
        ],
        title="Plant Disease Detection",
        description="Upload a leaf image to detect the disease.",
    )
    interface.launch(share=share, debug=debug)
|
|
|
# Start the web UI only when executed as a script, not when imported.
if __name__ == "__main__":
    launch_web()
|
|