Ace10er commited on
Commit
ef05174
·
verified ·
1 Parent(s): 0ed38bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -11
app.py CHANGED
@@ -1,29 +1,99 @@
1
-
2
  import torch
3
- from transformers import AutoImageProcessor, AutoModelForImageClassification
4
- from PIL import Image
5
  import gradio as gr
 
 
 
 
 
 
 
6
 
7
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
 
 
 
9
  def predict(image):
10
  processor = AutoImageProcessor.from_pretrained("linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification")
11
  model = AutoModelForImageClassification.from_pretrained("linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification").to(device)
12
  model.eval()
13
 
 
 
 
14
  inputs = processor(images=image, return_tensors="pt").to(device)
 
15
  with torch.no_grad():
16
  logits = model(**inputs).logits
17
  predicted_class = logits.argmax(-1).item()
18
  class_name = model.config.id2label.get(predicted_class, "Unknown")
 
19
  return class_name
20
 
21
- iface = gr.Interface(
22
- fn=predict,
23
- inputs=gr.Image(type="pil"),
24
- outputs=gr.Label(),
25
- title="Plant Disease Classifier",
26
- description="Upload a leaf image to predict plant disease."
27
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- iface.launch()
 
 
 
1
  import torch
 
 
2
  import gradio as gr
3
+ from PIL import Image
4
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
5
+ from selenium import webdriver
6
+ from selenium.webdriver.chrome.options import Options
7
+ import time
8
+
9
+ # ========== Configuration ==========
10
 
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
 
13
+ # ========== Predict using MobileNetV2 from Hugging Face ==========
14
+
15
  def predict(image):
16
  processor = AutoImageProcessor.from_pretrained("linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification")
17
  model = AutoModelForImageClassification.from_pretrained("linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification").to(device)
18
  model.eval()
19
 
20
+ if isinstance(image, str):
21
+ image = Image.open(image).convert("RGB")
22
+
23
  inputs = processor(images=image, return_tensors="pt").to(device)
24
+
25
  with torch.no_grad():
26
  logits = model(**inputs).logits
27
  predicted_class = logits.argmax(-1).item()
28
  class_name = model.config.id2label.get(predicted_class, "Unknown")
29
+
30
  return class_name
31
 
32
+ # ========== Scrape Disease Info ==========
33
+
34
+ def get_disease_info(disease_name):
35
+ chrome_options = Options()
36
+ chrome_options.add_argument("--headless") # Run headless (no UI)
37
+ driver = webdriver.Chrome(options=chrome_options)
38
+
39
+ # Perform a Google search for the disease
40
+ search_query = f"{disease_name} plant disease causes symptoms treatment"
41
+ search_url = f"https://www.google.com/search?q={search_query}"
42
+
43
+ try:
44
+ driver.get(search_url)
45
+ time.sleep(2) # Wait for JavaScript to load the results
46
+
47
+ # Find the first relevant link
48
+ results = driver.find_elements_by_xpath("//a[@href]")
49
+
50
+ # Find the first valid link that seems to be from a trusted source (like Wikipedia)
51
+ disease_info_url = None
52
+ for result in results:
53
+ href = result.get_attribute("href")
54
+ if href and "wikipedia" in href:
55
+ disease_info_url = href
56
+ break
57
+
58
+ if not disease_info_url:
59
+ return f"No relevant information found for '{disease_name}'."
60
+
61
+ # Now, scrape the disease details from the Wikipedia link
62
+ driver.get(disease_info_url)
63
+ time.sleep(2)
64
+
65
+ # Extract text content from the page (just the first few paragraphs)
66
+ paragraphs = driver.find_elements_by_tag_name("p")
67
+ disease_info = []
68
+
69
+ for paragraph in paragraphs[:5]: # Limit to 5 paragraphs for brevity
70
+ disease_info.append(paragraph.text.strip())
71
+
72
+ info_summary = "\n".join(disease_info)
73
+ return info_summary
74
+
75
+ except Exception as e:
76
+ return f"Could not fetch information for '{disease_name}'. Error: {str(e)}"
77
+
78
+ finally:
79
+ driver.quit()
80
+
81
+ # ========== Web Interface ==========
82
+
83
+ def wrapped_predict(image):
84
+ disease_name = predict(image)
85
+ disease_info = get_disease_info(disease_name)
86
+ return disease_name, disease_info
87
+
88
+ def launch_web():
89
+ interface = gr.Interface(
90
+ fn=wrapped_predict,
91
+ inputs=gr.Image(type="pil", label="Upload Leaf Image"),
92
+ outputs=[gr.Label(label="Predicted Disease"), gr.Textbox(label="Precautions and Solutions")],
93
+ title="Plant Disease Detection",
94
+ description="Upload a leaf image to detect the disease and receive actionable solutions fetched from the web."
95
+ )
96
+ interface.launch(share=True, debug=True)
97
 
98
+ if __name__ == "__main__":
99
+ launch_web()