Update app.py
Browse files
app.py
CHANGED
@@ -7,32 +7,34 @@ import numpy as np
|
|
7 |
import matplotlib.pyplot as plt
|
8 |
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
|
9 |
from sklearn.metrics import confusion_matrix
|
|
|
|
|
10 |
|
11 |
-
#
|
12 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
13 |
|
14 |
-
#
|
15 |
-
model_path = "./models" # Store model weights inside the space directory
|
16 |
-
os.makedirs(model_path, exist_ok=True)
|
17 |
-
|
18 |
-
# Load the trained MobileNetV2 model from Hugging Face Hub
|
19 |
def load_model():
|
20 |
processor = AutoImageProcessor.from_pretrained("linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification")
|
21 |
model = AutoModelForImageClassification.from_pretrained("linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification").to(device)
|
22 |
model.eval() # Set the model to evaluation mode
|
23 |
return processor, model
|
24 |
|
|
|
25 |
processor, model = load_model()
|
26 |
|
27 |
-
# Define the class names -
|
28 |
class_names = ['Healthy', 'Disease_1', 'Disease_2', 'Disease_3'] # Update with your actual classes
|
29 |
|
|
|
30 |
def predict(image):
|
31 |
if isinstance(image, str):
|
32 |
image = Image.open(image).convert("RGB")
|
33 |
|
|
|
34 |
inputs = processor(images=image, return_tensors="pt").to(device)
|
35 |
|
|
|
36 |
with torch.no_grad():
|
37 |
logits = model(**inputs).logits
|
38 |
predicted_class_idx = logits.argmax(-1).item()
|
@@ -40,10 +42,7 @@ def predict(image):
|
|
40 |
|
41 |
return predicted_class
|
42 |
|
43 |
-
#
|
44 |
-
from urllib.parse import quote
|
45 |
-
import requests
|
46 |
-
|
47 |
def get_detailed_wikipedia_info(query):
|
48 |
try:
|
49 |
# Step 1: Search for the article
|
@@ -80,7 +79,7 @@ def get_detailed_wikipedia_info(query):
|
|
80 |
|
81 |
# Wrapped predict function
|
82 |
def wrapped_predict(image):
|
83 |
-
# Step 1: Use
|
84 |
disease_name = predict(image)
|
85 |
|
86 |
# Step 2: Use the disease name to query detailed Wikipedia info
|
|
|
7 |
import matplotlib.pyplot as plt
|
8 |
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
|
9 |
from sklearn.metrics import confusion_matrix
|
10 |
+
from urllib.parse import quote
|
11 |
+
import requests
|
12 |
|
13 |
+
# Ensure the model is loaded only once in the deployment environment
|
14 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
15 |
|
16 |
+
# Model loading function
|
|
|
|
|
|
|
|
|
17 |
def load_model():
|
18 |
processor = AutoImageProcessor.from_pretrained("linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification")
|
19 |
model = AutoModelForImageClassification.from_pretrained("linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification").to(device)
|
20 |
model.eval() # Set the model to evaluation mode
|
21 |
return processor, model
|
22 |
|
23 |
+
# Load model and processor
|
24 |
processor, model = load_model()
|
25 |
|
26 |
+
# Define the class names - Update with your actual class names
|
27 |
class_names = ['Healthy', 'Disease_1', 'Disease_2', 'Disease_3'] # Update with your actual classes
|
28 |
|
29 |
+
# Prediction function
|
30 |
def predict(image):
|
31 |
if isinstance(image, str):
|
32 |
image = Image.open(image).convert("RGB")
|
33 |
|
34 |
+
# Preprocess the image
|
35 |
inputs = processor(images=image, return_tensors="pt").to(device)
|
36 |
|
37 |
+
# Inference
|
38 |
with torch.no_grad():
|
39 |
logits = model(**inputs).logits
|
40 |
predicted_class_idx = logits.argmax(-1).item()
|
|
|
42 |
|
43 |
return predicted_class
|
44 |
|
45 |
+
# Function to get detailed Wikipedia info
|
|
|
|
|
|
|
46 |
def get_detailed_wikipedia_info(query):
|
47 |
try:
|
48 |
# Step 1: Search for the article
|
|
|
79 |
|
80 |
# Wrapped predict function
|
81 |
def wrapped_predict(image):
|
82 |
+
# Step 1: Use the model to predict the disease name
|
83 |
disease_name = predict(image)
|
84 |
|
85 |
# Step 2: Use the disease name to query detailed Wikipedia info
|