File size: 3,995 Bytes
0ed38bf
 
ef05174
f81dce2
ef05174
f81dce2
 
 
 
a246748
 
0ed38bf
a246748
0ed38bf
 
a246748
f81dce2
0ed38bf
 
f81dce2
 
 
a246748
f81dce2
0ed38bf
a246748
f81dce2
 
a246748
f81dce2
ef05174
 
f81dce2
a246748
0ed38bf
f81dce2
a246748
0ed38bf
 
f81dce2
 
 
 
ef05174
a246748
f81dce2
ef05174
f81dce2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef05174
f81dce2
ef05174
f81dce2
ef05174
a246748
ef05174
f81dce2
 
 
 
ef05174
 
f81dce2
ef05174
 
 
 
f81dce2
 
 
 
ef05174
f81dce2
ef05174
 
0ed38bf
ef05174
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch
import gradio as gr
from PIL import Image
import os
from transformers import AutoImageProcessor, AutoModelForImageClassification
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import confusion_matrix
from urllib.parse import quote
import requests

# Resolve the compute device once at import time; the model below is moved
# onto it and every preprocessed batch is sent to the same place.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# Model loading function
def load_model():
    processor = AutoImageProcessor.from_pretrained("linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification")
    model = AutoModelForImageClassification.from_pretrained("linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification").to(device)
    model.eval()  # Set the model to evaluation mode
    return processor, model

# Load model and processor once at import time so every request reuses them
processor, model = load_model()

# Fallback label list — placeholder values only; the deployed checkpoint's
# real class labels live in model.config.id2label.
class_names = ['Healthy', 'Disease_1', 'Disease_2', 'Disease_3']  # Update with your actual classes

# Prediction function
def predict(image):
    if isinstance(image, str):
        image = Image.open(image).convert("RGB")
    
    # Preprocess the image
    inputs = processor(images=image, return_tensors="pt").to(device)
    
    # Inference
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    predicted_class = class_names[predicted_class_idx]
    
    return predicted_class

# Function to get detailed Wikipedia info
def get_detailed_wikipedia_info(query):
    try:
        # Step 1: Search for the article
        search_url = f"https://en.wikipedia.org/w/api.php?action=query&list=search&srsearch={quote(query)}&format=json"
        search_response = requests.get(search_url).json()
        search_results = search_response.get("query", {}).get("search", [])
        
        if not search_results:
            return query, "No Wikipedia pages found."
        
        # Get the top result title
        top_title = search_results[0]["title"]
        
        # Step 2: Get full page content (plaintext extract)
        extract_url = f"https://en.wikipedia.org/w/api.php?action=query&prop=extracts&exintro=&explaintext=&titles={quote(top_title)}&format=json"
        extract_response = requests.get(extract_url).json()
        
        page_data = extract_response.get("query", {}).get("pages", {})
        page = next(iter(page_data.values()))
        extract_text = page.get("extract", "No detailed info found.")
        
        # Step 3: Format output with Markdown link
        page_link = f"https://en.wikipedia.org/wiki/{quote(top_title.replace(' ', '_'))}"
        info = (
            f"### 🌿 **{top_title}**\n\n"
            f"{extract_text.strip()}\n\n"
            f"🔗 [Click here to read more on Wikipedia]({page_link})"
        )
        
        return top_title, info
    
    except Exception as e:
        return query, f"Error fetching info: {str(e)}"

# Wrapped predict function
def wrapped_predict(image):
    # Step 1: Use the model to predict the disease name
    disease_name = predict(image)
    
    # Step 2: Use the disease name to query detailed Wikipedia info
    disease_api_name, disease_info = get_detailed_wikipedia_info(disease_name)
    
    return disease_name, disease_info

# Gradio Interface for Web Deployment
def launch_web():
    interface = gr.Interface(
        fn=wrapped_predict,
        inputs=gr.Image(type="pil", label="Upload Leaf Image"),
        outputs=[
            gr.Label(label="Predicted Disease"),  # For displaying the disease name
            gr.Textbox(label="Disease Information")  # For displaying the detailed information
        ],
        title="Plant Disease Detection",
        description="Upload a leaf image to detect the disease."
    )
    interface.launch(share=True, debug=True)

# Entry point: start the web UI only when run as a script, not on import.
if __name__ == "__main__":
    launch_web()