Update app.py
Browse files
app.py
CHANGED
@@ -1,29 +1,99 @@
|
|
1 |
-
|
2 |
import torch
|
3 |
-
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
4 |
-
from PIL import Image
|
5 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
8 |
|
|
|
|
|
9 |
def predict(image):
|
10 |
processor = AutoImageProcessor.from_pretrained("linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification")
|
11 |
model = AutoModelForImageClassification.from_pretrained("linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification").to(device)
|
12 |
model.eval()
|
13 |
|
|
|
|
|
|
|
14 |
inputs = processor(images=image, return_tensors="pt").to(device)
|
|
|
15 |
with torch.no_grad():
|
16 |
logits = model(**inputs).logits
|
17 |
predicted_class = logits.argmax(-1).item()
|
18 |
class_name = model.config.id2label.get(predicted_class, "Unknown")
|
|
|
19 |
return class_name
|
20 |
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
-
|
|
|
|
|
|
1 |
import torch
|
|
|
|
|
2 |
import gradio as gr
|
3 |
+
from PIL import Image
|
4 |
+
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
5 |
+
from selenium import webdriver
|
6 |
+
from selenium.webdriver.chrome.options import Options
|
7 |
+
import time
|
8 |
+
|
9 |
+
# ========== Configuration ==========
|
10 |
|
11 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
12 |
|
13 |
+
# ========== Predict using MobileNetV2 from Hugging Face ==========
|
14 |
+
|
15 |
def predict(image):
|
16 |
processor = AutoImageProcessor.from_pretrained("linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification")
|
17 |
model = AutoModelForImageClassification.from_pretrained("linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification").to(device)
|
18 |
model.eval()
|
19 |
|
20 |
+
if isinstance(image, str):
|
21 |
+
image = Image.open(image).convert("RGB")
|
22 |
+
|
23 |
inputs = processor(images=image, return_tensors="pt").to(device)
|
24 |
+
|
25 |
with torch.no_grad():
|
26 |
logits = model(**inputs).logits
|
27 |
predicted_class = logits.argmax(-1).item()
|
28 |
class_name = model.config.id2label.get(predicted_class, "Unknown")
|
29 |
+
|
30 |
return class_name
|
31 |
|
32 |
+
# ========== Scrape Disease Info ==========
|
33 |
+
|
34 |
+
def get_disease_info(disease_name):
|
35 |
+
chrome_options = Options()
|
36 |
+
chrome_options.add_argument("--headless") # Run headless (no UI)
|
37 |
+
driver = webdriver.Chrome(options=chrome_options)
|
38 |
+
|
39 |
+
# Perform a Google search for the disease
|
40 |
+
search_query = f"{disease_name} plant disease causes symptoms treatment"
|
41 |
+
search_url = f"https://www.google.com/search?q={search_query}"
|
42 |
+
|
43 |
+
try:
|
44 |
+
driver.get(search_url)
|
45 |
+
time.sleep(2) # Wait for JavaScript to load the results
|
46 |
+
|
47 |
+
# Find the first relevant link
|
48 |
+
results = driver.find_elements_by_xpath("//a[@href]")
|
49 |
+
|
50 |
+
# Find the first valid link that seems to be from a trusted source (like Wikipedia)
|
51 |
+
disease_info_url = None
|
52 |
+
for result in results:
|
53 |
+
href = result.get_attribute("href")
|
54 |
+
if href and "wikipedia" in href:
|
55 |
+
disease_info_url = href
|
56 |
+
break
|
57 |
+
|
58 |
+
if not disease_info_url:
|
59 |
+
return f"No relevant information found for '{disease_name}'."
|
60 |
+
|
61 |
+
# Now, scrape the disease details from the Wikipedia link
|
62 |
+
driver.get(disease_info_url)
|
63 |
+
time.sleep(2)
|
64 |
+
|
65 |
+
# Extract text content from the page (just the first few paragraphs)
|
66 |
+
paragraphs = driver.find_elements_by_tag_name("p")
|
67 |
+
disease_info = []
|
68 |
+
|
69 |
+
for paragraph in paragraphs[:5]: # Limit to 5 paragraphs for brevity
|
70 |
+
disease_info.append(paragraph.text.strip())
|
71 |
+
|
72 |
+
info_summary = "\n".join(disease_info)
|
73 |
+
return info_summary
|
74 |
+
|
75 |
+
except Exception as e:
|
76 |
+
return f"Could not fetch information for '{disease_name}'. Error: {str(e)}"
|
77 |
+
|
78 |
+
finally:
|
79 |
+
driver.quit()
|
80 |
+
|
81 |
+
# ========== Web Interface ==========
|
82 |
+
|
83 |
+
def wrapped_predict(image):
|
84 |
+
disease_name = predict(image)
|
85 |
+
disease_info = get_disease_info(disease_name)
|
86 |
+
return disease_name, disease_info
|
87 |
+
|
88 |
+
def launch_web():
|
89 |
+
interface = gr.Interface(
|
90 |
+
fn=wrapped_predict,
|
91 |
+
inputs=gr.Image(type="pil", label="Upload Leaf Image"),
|
92 |
+
outputs=[gr.Label(label="Predicted Disease"), gr.Textbox(label="Precautions and Solutions")],
|
93 |
+
title="Plant Disease Detection",
|
94 |
+
description="Upload a leaf image to detect the disease and receive actionable solutions fetched from the web."
|
95 |
+
)
|
96 |
+
interface.launch(share=True, debug=True)
|
97 |
|
98 |
+
if __name__ == "__main__":
|
99 |
+
launch_web()
|