Ace10er committed on
Commit f81dce2 · verified · 1 Parent(s): 4917e5f

Update app.py

Files changed (1):
  app.py +71 -64
app.py CHANGED
@@ -1,97 +1,104 @@
--- app.py (before)

  import torch
  import gradio as gr
  from PIL import Image
  from transformers import AutoImageProcessor, AutoModelForImageClassification
- from selenium import webdriver
- from selenium.webdriver.chrome.options import Options
- import time
-
- # ========== Configuration ==========

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

- # ========== Predict using MobileNetV2 from Hugging Face ==========

- def predict(image):
      processor = AutoImageProcessor.from_pretrained("linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification")
      model = AutoModelForImageClassification.from_pretrained("linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification").to(device)
-     model.eval()

      if isinstance(image, str):
          image = Image.open(image).convert("RGB")
-
      inputs = processor(images=image, return_tensors="pt").to(device)
-
      with torch.no_grad():
          logits = model(**inputs).logits
-     predicted_class = logits.argmax(-1).item()
-     class_name = model.config.id2label.get(predicted_class, "Unknown")
-
-     return class_name
-
- # ========== Scrape Disease Info ==========
-
- def get_disease_info(disease_name):
-     chrome_options = Options()
-     chrome_options.add_argument("--headless")  # Run headless (no UI)
-     driver = webdriver.Chrome(options=chrome_options)

-     # Perform a Google search for the disease
-     search_query = f"{disease_name} plant disease causes symptoms treatment"
-     search_url = f"https://www.google.com/search?q={search_query}"

      try:
-         driver.get(search_url)
-         time.sleep(2)  # Wait for JavaScript to load the results
-
-         # Find the first relevant link
-         results = driver.find_elements_by_xpath("//a[@href]")
-
-         # Find the first valid link that seems to be from a trusted source (like Wikipedia)
-         disease_info_url = None
-         for result in results:
-             href = result.get_attribute("href")
-             if href and "wikipedia" in href:
-                 disease_info_url = href
-                 break
-
-         if not disease_info_url:
-             return f"No relevant information found for '{disease_name}'."
-
-         # Now, scrape the disease details from the Wikipedia link
-         driver.get(disease_info_url)
-         time.sleep(2)
-
-         # Extract text content from the page (just the first few paragraphs)
-         paragraphs = driver.find_elements_by_tag_name("p")
-         disease_info = []
-
-         for paragraph in paragraphs[:5]:  # Limit to 5 paragraphs for brevity
-             disease_info.append(paragraph.text.strip())
-
-         info_summary = "\n".join(disease_info)
-         return info_summary
-
      except Exception as e:
-         return f"Could not fetch information for '{disease_name}'. Error: {str(e)}"
-
-     finally:
-         driver.quit()
-
- # ========== Web Interface ==========

  def wrapped_predict(image):
      disease_name = predict(image)
-     disease_info = get_disease_info(disease_name)
      return disease_name, disease_info

  def launch_web():
      interface = gr.Interface(
          fn=wrapped_predict,
          inputs=gr.Image(type="pil", label="Upload Leaf Image"),
-         outputs=[gr.Label(label="Predicted Disease"), gr.Textbox(label="Precautions and Solutions")],
          title="Plant Disease Detection",
-         description="Upload a leaf image to detect the disease and receive actionable solutions fetched from the web."
      )
      interface.launch(share=True, debug=True)
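A note on the removed scraper: `find_elements_by_xpath` and `find_elements_by_tag_name` were removed in Selenium 4, so this code only ran against Selenium 3. For reference, a minimal sketch of the same lookups with the Selenium 4 `By` API (assumes a local chromedriver; the search query is only an example):

```python
# Sketch: the removed lookups ported to Selenium 4's By-based API.
from selenium import webdriver
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.common.by import By

options = Options()
options.add_argument("--headless")
driver = webdriver.Chrome(options=options)
try:
    driver.get("https://www.google.com/search?q=tomato+leaf+disease")
    links = driver.find_elements(By.XPATH, "//a[@href]")  # was: find_elements_by_xpath
    paragraphs = driver.find_elements(By.TAG_NAME, "p")   # was: find_elements_by_tag_name
    print(len(links), "links,", len(paragraphs), "paragraphs")
finally:
    driver.quit()
```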
 
 
+++ app.py (after)

  import torch
  import gradio as gr
  from PIL import Image
+ import os
  from transformers import AutoImageProcessor, AutoModelForImageClassification
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
+ from sklearn.metrics import confusion_matrix

+ # Make sure the model is loaded only once in the deployment environment
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+ # Assuming model path is provided by Hugging Face Spaces
+ model_path = "./models"  # Store model weights inside the space directory
+ os.makedirs(model_path, exist_ok=True)

+ # Load the trained MobileNetV2 model from Hugging Face Hub
+ def load_model():
      processor = AutoImageProcessor.from_pretrained("linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification")
      model = AutoModelForImageClassification.from_pretrained("linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification").to(device)
+     model.eval()  # Set the model to evaluation mode
+     return processor, model
+
+ processor, model = load_model()
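As added, `model_path` is created but never passed to either `from_pretrained` call, so the weights still land in the default Hugging Face cache rather than in `./models`. If the intent is to keep them inside the Space directory, `cache_dir` is the standard hook; a sketch under that assumption:

```python
# Sketch: route downloads into model_path via cache_dir (assumes the
# os.makedirs(model_path) above was meant to hold the cached weights).
def load_model():
    repo = "linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification"
    processor = AutoImageProcessor.from_pretrained(repo, cache_dir=model_path)
    model = AutoModelForImageClassification.from_pretrained(repo, cache_dir=model_path).to(device)
    model.eval()
    return processor, model
```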
 
+ # Define the class names - These should match the labels in the dataset
+ class_names = ['Healthy', 'Disease_1', 'Disease_2', 'Disease_3']  # Update with your actual classes
+
+ def predict(image):
      if isinstance(image, str):
          image = Image.open(image).convert("RGB")
+
      inputs = processor(images=image, return_tensors="pt").to(device)
+
      with torch.no_grad():
          logits = model(**inputs).logits
+     predicted_class_idx = logits.argmax(-1).item()
+     predicted_class = class_names[predicted_class_idx]
+
+     return predicted_class
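The hard-coded four-entry `class_names` placeholder has to match the checkpoint's label set exactly, or `class_names[predicted_class_idx]` will mislabel results or raise an `IndexError`; the linkanjarad checkpoint almost certainly covers more classes than these four. The removed revision read labels from the model config, which cannot drift from the checkpoint; a sketch of that variant:

```python
# Sketch: look labels up in the checkpoint's own id2label mapping instead of
# a hand-maintained class_names list (mirrors the pre-commit predict()).
def predict(image):
    if isinstance(image, str):
        image = Image.open(image).convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label.get(predicted_class_idx, "Unknown")
```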
 
+ # Detailed Wikipedia info
+ from urllib.parse import quote
+ import requests

+ def get_detailed_wikipedia_info(query):
      try:
+         # Step 1: Search for the article
+         search_url = f"https://en.wikipedia.org/w/api.php?action=query&list=search&srsearch={quote(query)}&format=json"
+         search_response = requests.get(search_url).json()
+         search_results = search_response.get("query", {}).get("search", [])
+
+         if not search_results:
+             return query, "No Wikipedia pages found."
+
+         # Get the top result title
+         top_title = search_results[0]["title"]
+
+         # Step 2: Get full page content (plaintext extract)
+         extract_url = f"https://en.wikipedia.org/w/api.php?action=query&prop=extracts&exintro=&explaintext=&titles={quote(top_title)}&format=json"
+         extract_response = requests.get(extract_url).json()
+
+         page_data = extract_response.get("query", {}).get("pages", {})
+         page = next(iter(page_data.values()))
+         extract_text = page.get("extract", "No detailed info found.")
+
+         # Step 3: Format output with Markdown link
+         page_link = f"https://en.wikipedia.org/wiki/{quote(top_title.replace(' ', '_'))}"
+         info = (
+             f"### 🌿 **{top_title}**\n\n"
+             f"{extract_text.strip()}\n\n"
+             f"🔗 [Click here to read more on Wikipedia]({page_link})"
+         )
+
+         return top_title, info
+
      except Exception as e:
+         return query, f"Error fetching info: {str(e)}"
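The two-step MediaWiki flow (a `list=search` query for the best-matching title, then `prop=extracts` for that page's plaintext intro) can be smoke-tested on its own; a quick check with an example query, assuming network access:

```python
# Quick manual check of the Wikipedia lookup; the query is only an example.
title, info = get_detailed_wikipedia_info("Tomato late blight")
print(title)       # top-ranked article title for the query
print(info[:300])  # start of the Markdown-formatted summary
```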
 
+ # Wrapped predict function
  def wrapped_predict(image):
+     # Step 1: Use your local ML model to predict the disease name
      disease_name = predict(image)
+
+     # Step 2: Use the disease name to query detailed Wikipedia info
+     disease_api_name, disease_info = get_detailed_wikipedia_info(disease_name)
+
      return disease_name, disease_info

+ # Gradio Interface for Web Deployment
  def launch_web():
      interface = gr.Interface(
          fn=wrapped_predict,
          inputs=gr.Image(type="pil", label="Upload Leaf Image"),
+         outputs=[
+             gr.Label(label="Predicted Disease"),  # For displaying the disease name
+             gr.Textbox(label="Disease Information")  # For displaying the detailed information
+         ],
          title="Plant Disease Detection",
+         description="Upload a leaf image to detect the disease."
      )
      interface.launch(share=True, debug=True)
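Two loose ends in the revised file: `get_detailed_wikipedia_info` returns Markdown (a heading, bold text, and a link) that `gr.Textbox` will show as raw text, so `gr.Markdown` may be the better output component; and nothing in the file as shown ever calls `launch_web()`, so running app.py would exit without starting the interface. A minimal presumed entry point (not part of the commit as displayed):

```python
# Presumed entry point; app.py as shown never invokes launch_web().
if __name__ == "__main__":
    launch_web()
```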