File size: 3,995 Bytes
0ed38bf ef05174 f81dce2 ef05174 f81dce2 a246748 0ed38bf a246748 0ed38bf a246748 f81dce2 0ed38bf f81dce2 a246748 f81dce2 0ed38bf a246748 f81dce2 a246748 f81dce2 ef05174 f81dce2 a246748 0ed38bf f81dce2 a246748 0ed38bf f81dce2 ef05174 a246748 f81dce2 ef05174 f81dce2 ef05174 f81dce2 ef05174 f81dce2 ef05174 a246748 ef05174 f81dce2 ef05174 f81dce2 ef05174 f81dce2 ef05174 f81dce2 ef05174 0ed38bf ef05174 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import torch
import gradio as gr
from PIL import Image
import os
from transformers import AutoImageProcessor, AutoModelForImageClassification
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import confusion_matrix
from urllib.parse import quote
import requests
# Shared compute device for the model and its input tensors: a CUDA GPU when
# PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Hugging Face Hub checkpoint this app was built around.
DEFAULT_MODEL_ID = "linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification"


def load_model(model_id=DEFAULT_MODEL_ID):
    """Load the image processor and classification model for *model_id*.

    The model is moved to the module-level ``device`` and put into evaluation
    mode before it is returned.

    Args:
        model_id: Hub repo id (or local path) of an image-classification
            checkpoint. Defaults to ``DEFAULT_MODEL_ID``.

    Returns:
        A ``(processor, model)`` tuple.
    """
    processor = AutoImageProcessor.from_pretrained(model_id)
    model = AutoModelForImageClassification.from_pretrained(model_id).to(device)
    # Inference only: switch layers such as dropout / batch-norm to eval behaviour.
    model.eval()
    return processor, model
# Load the processor and model once at import time; every request reuses them.
processor, model = load_model()

# Human-readable class names, indexed by the model's output position.
# They are taken from the checkpoint's own id2label mapping, so every index
# that argmax can return has a matching label. A hand-written list must have
# exactly model.config.num_labels entries in the checkpoint's label order;
# otherwise predict() mislabels results or raises IndexError.
class_names = [model.config.id2label[i] for i in range(model.config.num_labels)]
# Prediction function
def predict(image):
    """Classify a leaf image and return the predicted class name.

    Args:
        image: A filesystem path (``str``) to an image, or an already-loaded
            ``PIL.Image.Image`` such as the one ``gr.Image(type="pil")`` passes.

    Returns:
        The entry of ``class_names`` at the arg-max logit index.

    Raises:
        IndexError: if the predicted index is not covered by ``class_names``.
    """
    if isinstance(image, str):
        # Use a context manager so the file handle is closed once convert()
        # has copied the pixel data, instead of leaking until GC.
        with Image.open(image) as opened:
            image = opened.convert("RGB")
    elif isinstance(image, Image.Image) and image.mode != "RGB":
        # Uploads may be RGBA / palette / greyscale; normalize to 3 channels
        # like the path branch does.
        image = image.convert("RGB")

    # Preprocess (resize/normalize) and move the tensors to the model's device.
    inputs = processor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()

    if not 0 <= predicted_class_idx < len(class_names):
        raise IndexError(
            f"Model predicted class index {predicted_class_idx}, but class_names "
            f"has only {len(class_names)} entries; make class_names match the "
            f"model's label set."
        )
    return class_names[predicted_class_idx]
_WIKIPEDIA_API_URL = "https://en.wikipedia.org/w/api.php"
# Wikimedia asks API clients to identify themselves with a descriptive
# User-Agent; generic library defaults may be throttled or blocked.
_WIKIPEDIA_HEADERS = {"User-Agent": "PlantDiseaseDetection/1.0 (Gradio demo)"}


def get_detailed_wikipedia_info(query, timeout=10):
    """Look up *query* on English Wikipedia and return a Markdown summary.

    The top full-text search hit is used; its lead section is fetched as
    plain text and formatted with a heading and a link to the article.

    Args:
        query: Free-text search string (e.g. a predicted disease name).
        timeout: Seconds to wait for each HTTP request before giving up.

    Returns:
        ``(title, info)``: the matched article title (or *query* when nothing
        was found or an error occurred) and a Markdown string. Errors are
        reported in the returned text rather than raised, so the UI stays up.
    """
    if not query or not query.strip():
        return query, "No Wikipedia pages found."
    try:
        # Step 1: full-text search; take the top-ranked article.
        search_response = requests.get(
            _WIKIPEDIA_API_URL,
            params={
                "action": "query",
                "list": "search",
                "srsearch": query,
                "format": "json",
            },
            headers=_WIKIPEDIA_HEADERS,
            timeout=timeout,
        )
        search_response.raise_for_status()
        search_results = search_response.json().get("query", {}).get("search", [])
        if not search_results:
            return query, "No Wikipedia pages found."
        top_title = search_results[0]["title"]

        # Step 2: plain-text extract of the article's lead section only
        # (exintro); the flags are enabled by being present with empty values.
        extract_response = requests.get(
            _WIKIPEDIA_API_URL,
            params={
                "action": "query",
                "prop": "extracts",
                "exintro": "",
                "explaintext": "",
                "titles": top_title,
                "format": "json",
            },
            headers=_WIKIPEDIA_HEADERS,
            timeout=timeout,
        )
        extract_response.raise_for_status()
        pages = extract_response.json().get("query", {}).get("pages", {})
        # Pages are keyed by page id; one requested title yields one entry.
        page = next(iter(pages.values()), {})
        extract_text = page.get("extract", "No detailed info found.")

        # Step 3: format output with a Markdown link to the article.
        page_link = f"https://en.wikipedia.org/wiki/{quote(top_title.replace(' ', '_'))}"
        info = (
            f"### 🌿 **{top_title}**\n\n"
            f"{extract_text.strip()}\n\n"
            f"🔗 [Click here to read more on Wikipedia]({page_link})"
        )
        return top_title, info
    except Exception as e:  # UI boundary: report any failure as text instead of crashing
        return query, f"Error fetching info: {str(e)}"
# Wrapped predict function
def wrapped_predict(image):
    """Gradio callback: classify *image*, then fetch Wikipedia info for the result.

    Args:
        image: The uploaded image, or ``None`` when the user submitted
            without uploading anything.

    Returns:
        ``(disease_name, disease_info)`` for the label and info outputs.
        When *image* is ``None``, returns ``(None, prompt)`` without running
        the model.
    """
    if image is None:
        return None, "Please upload a leaf image."

    # Step 1: use the model to predict the disease name.
    disease_name = predict(image)

    # Step 2: labels such as "Disease_1" use underscores as separators;
    # spaces give Wikipedia's full-text search much better matches.
    search_query = disease_name.replace("_", " ")
    _, disease_info = get_detailed_wikipedia_info(search_query)
    return disease_name, disease_info
# Gradio Interface for Web Deployment
def launch_web(*, share=True, debug=True):
    """Build the Gradio interface around ``wrapped_predict`` and launch it.

    Args:
        share: Passed to ``Interface.launch``; request a public share link.
        debug: Passed to ``Interface.launch``; when True, launch blocks the
            calling thread and surfaces errors in the console.
    """
    interface = gr.Interface(
        fn=wrapped_predict,
        inputs=gr.Image(type="pil", label="Upload Leaf Image"),
        outputs=[
            gr.Label(label="Predicted Disease"),
            # The info text contains Markdown (heading, bold, link), so render
            # it instead of showing raw markup in a Textbox.
            gr.Markdown(label="Disease Information"),
        ],
        title="Plant Disease Detection",
        description="Upload a leaf image to detect the disease.",
    )
    interface.launch(share=share, debug=debug)
if __name__ == "__main__":  # start the server only when run as a script, not on import
    launch_web()
|