import torch
import gradio as gr
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
from urllib.parse import quote
import requests

# Ensure the model is loaded only once in the deployment environment
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model loading function
def load_model():
    model_id = "linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification"
    processor = AutoImageProcessor.from_pretrained(model_id)
    model = AutoModelForImageClassification.from_pretrained(model_id).to(device)
    model.eval()  # Set the model to evaluation mode
    return processor, model

# Load model and processor
processor, model = load_model()

# Pull the class names from the checkpoint config so the index->name mapping
# always matches the model head. (A hardcoded placeholder list shorter than
# the label set would raise IndexError at inference time.)
class_names = [model.config.id2label[i] for i in range(len(model.config.id2label))]

# Prediction function
def predict(image):
    if isinstance(image, str):
        image = Image.open(image).convert("RGB")
    # Preprocess the image
    inputs = processor(images=image, return_tensors="pt").to(device)
    # Inference
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    return class_names[predicted_class_idx]

# Function to get detailed Wikipedia info
def get_detailed_wikipedia_info(query):
    try:
        # Step 1: Search for the article
        search_url = (
            "https://en.wikipedia.org/w/api.php"
            f"?action=query&list=search&srsearch={quote(query)}&format=json"
        )
        search_response = requests.get(search_url, timeout=10).json()
        search_results = search_response.get("query", {}).get("search", [])
        if not search_results:
            return query, "No Wikipedia pages found."

        # Get the top result title
        top_title = search_results[0]["title"]

        # Step 2: Get the page's plaintext intro extract
        extract_url = (
            "https://en.wikipedia.org/w/api.php"
            f"?action=query&prop=extracts&exintro=&explaintext=&titles={quote(top_title)}&format=json"
        )
        extract_response = requests.get(extract_url, timeout=10).json()
        page_data = extract_response.get("query", {}).get("pages", {})
        page = next(iter(page_data.values()))
        extract_text = page.get("extract", "No detailed info found.")

        # Step 3: Format the output with a Markdown link
        page_link = f"https://en.wikipedia.org/wiki/{quote(top_title.replace(' ', '_'))}"
        info = (
            f"### 🌿 **{top_title}**\n\n"
            f"{extract_text.strip()}\n\n"
            f"🔗 [Click here to read more on Wikipedia]({page_link})"
        )
        return top_title, info
    except Exception as e:
        return query, f"Error fetching info: {e}"

# Wrapped predict function: model prediction plus Wikipedia lookup
def wrapped_predict(image):
    # Step 1: Use the model to predict the disease name
    disease_name = predict(image)
    # Step 2: Use the disease name to query detailed Wikipedia info
    _, disease_info = get_detailed_wikipedia_info(disease_name)
    return disease_name, disease_info

# Gradio interface for web deployment
def launch_web():
    interface = gr.Interface(
        fn=wrapped_predict,
        inputs=gr.Image(type="pil", label="Upload Leaf Image"),
        outputs=[
            gr.Label(label="Predicted Disease"),        # Disease name
            gr.Markdown(label="Disease Information"),   # Renders the Markdown-formatted summary
        ],
        title="Plant Disease Detection",
        description="Upload a leaf image to detect the disease.",
    )
    interface.launch(share=True, debug=True)

if __name__ == "__main__":
    launch_web()
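
# ---------------------------------------------------------------------------
# Optional: offline evaluation sketch (not wired into the Gradio app).
# A minimal sketch assuming you have a labeled test set as two parallel lists,
# `test_image_paths` (file paths) and `test_labels` (ground-truth class names)
# -- both hypothetical names you would supply yourself. Because the `__main__`
# guard above launches the blocking Gradio server, call evaluate() from a
# separate script or interpreter after importing this module.
# ---------------------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
)

def evaluate(test_image_paths, test_labels):
    # Run the same predict() used by the app over the labeled test set
    predictions = [predict(path) for path in test_image_paths]

    # Aggregate metrics; macro averaging weights every class equally
    print(f"Accuracy : {accuracy_score(test_labels, predictions):.3f}")
    print(f"Precision: {precision_score(test_labels, predictions, average='macro', zero_division=0):.3f}")
    print(f"Recall   : {recall_score(test_labels, predictions, average='macro', zero_division=0):.3f}")
    print(f"F1 score : {f1_score(test_labels, predictions, average='macro', zero_division=0):.3f}")

    # Confusion matrix restricted to the classes that actually appear
    labels = sorted(set(test_labels) | set(predictions))
    cm = confusion_matrix(test_labels, predictions, labels=labels)
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(cm, cmap="Blues")
    ax.set_xticks(np.arange(len(labels)), labels=labels, rotation=90)
    ax.set_yticks(np.arange(len(labels)), labels=labels)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    fig.tight_layout()
    plt.savefig("confusion_matrix.png")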