Ace10er commited on
Commit
a246748
·
verified ·
1 Parent(s): f81dce2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -12
app.py CHANGED
@@ -7,32 +7,34 @@ import numpy as np
7
  import matplotlib.pyplot as plt
8
  from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
9
  from sklearn.metrics import confusion_matrix
 
 
10
 
11
- # Make sure the model is loaded only once in the deployment environment
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
 
14
- # Assuming model path is provided by Hugging Face Spaces
15
- model_path = "./models" # Store model weights inside the space directory
16
- os.makedirs(model_path, exist_ok=True)
17
-
18
- # Load the trained MobileNetV2 model from Hugging Face Hub
19
  def load_model():
20
  processor = AutoImageProcessor.from_pretrained("linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification")
21
  model = AutoModelForImageClassification.from_pretrained("linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification").to(device)
22
  model.eval() # Set the model to evaluation mode
23
  return processor, model
24
 
 
25
  processor, model = load_model()
26
 
27
- # Define the class names - These should match the labels in the dataset
28
  class_names = ['Healthy', 'Disease_1', 'Disease_2', 'Disease_3'] # Update with your actual classes
29
 
 
30
  def predict(image):
31
  if isinstance(image, str):
32
  image = Image.open(image).convert("RGB")
33
 
 
34
  inputs = processor(images=image, return_tensors="pt").to(device)
35
 
 
36
  with torch.no_grad():
37
  logits = model(**inputs).logits
38
  predicted_class_idx = logits.argmax(-1).item()
@@ -40,10 +42,7 @@ def predict(image):
40
 
41
  return predicted_class
42
 
43
- # Detailed Wikipedia info
44
- from urllib.parse import quote
45
- import requests
46
-
47
  def get_detailed_wikipedia_info(query):
48
  try:
49
  # Step 1: Search for the article
@@ -80,7 +79,7 @@ def get_detailed_wikipedia_info(query):
80
 
81
  # Wrapped predict function
82
  def wrapped_predict(image):
83
- # Step 1: Use your local ML model to predict the disease name
84
  disease_name = predict(image)
85
 
86
  # Step 2: Use the disease name to query detailed Wikipedia info
 
7
  import matplotlib.pyplot as plt
8
  from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
9
  from sklearn.metrics import confusion_matrix
10
+ from urllib.parse import quote
11
+ import requests
12
 
13
+ # Ensure the model is loaded only once in the deployment environment
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
+ # Model loading function
 
 
 
 
17
  def load_model():
18
  processor = AutoImageProcessor.from_pretrained("linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification")
19
  model = AutoModelForImageClassification.from_pretrained("linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification").to(device)
20
  model.eval() # Set the model to evaluation mode
21
  return processor, model
22
 
23
+ # Load model and processor
24
  processor, model = load_model()
25
 
26
+ # Define the class names - Update with your actual class names
27
  class_names = ['Healthy', 'Disease_1', 'Disease_2', 'Disease_3'] # Update with your actual classes
28
 
29
+ # Prediction function
30
  def predict(image):
31
  if isinstance(image, str):
32
  image = Image.open(image).convert("RGB")
33
 
34
+ # Preprocess the image
35
  inputs = processor(images=image, return_tensors="pt").to(device)
36
 
37
+ # Inference
38
  with torch.no_grad():
39
  logits = model(**inputs).logits
40
  predicted_class_idx = logits.argmax(-1).item()
 
42
 
43
  return predicted_class
44
 
45
+ # Function to get detailed Wikipedia info
 
 
 
46
  def get_detailed_wikipedia_info(query):
47
  try:
48
  # Step 1: Search for the article
 
79
 
80
  # Wrapped predict function
81
  def wrapped_predict(image):
82
+ # Step 1: Use the model to predict the disease name
83
  disease_name = predict(image)
84
 
85
  # Step 2: Use the disease name to query detailed Wikipedia info