Abuzaid01 committed
Commit cf83e69 (verified) · Parent: 6b55b7b

Upload plant disease classifier model

Files changed (4)
  1. README.md +23 -0
  2. config.json +25 -0
  3. inference.py +87 -0
  4. model.pth +3 -0
README.md ADDED
@@ -0,0 +1,23 @@
+ # Plant Disease Classification Model
+
+ A PyTorch model for classifying plant diseases in Apple, Tomato, and Corn crops.
+
+ ## Model Details
+
+ - **Model Type**: Image Classification
+ - **Architecture**: EfficientNet-B2 backbone with a channel-attention head
+ - **Input Size**: 224x224 RGB images
+ - **Output**: One of 14 disease / healthy classes (see config.json)
+
+ ## Usage
+
+ ```python
+ from inference import load_model, predict_image
+
+ # Load model
+ model, class_names = load_model("model.pth")
+
+ # Make prediction
+ results = predict_image("your_image.jpg", model, class_names)
+ print(results)
+ ```
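For reference, `predict_image` (defined in inference.py below) returns the top-3 predictions as a list of `{"label", "score"}` dictionaries. A minimal sketch of consuming that output, assuming model.pth and a sample image are available locally (file names are illustrative):

```python
from inference import load_model, predict_image

# Illustrative paths; any RGB leaf image should work.
model, class_names = load_model("model.pth")
results = predict_image("your_image.jpg", model, class_names)

# results is a list of {"label": str, "score": float} dicts, highest score first.
for rank, prediction in enumerate(results, start=1):
    print(f"{rank}. {prediction['label']}: {prediction['score']:.3f}")
```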
config.json ADDED
@@ -0,0 +1,25 @@
+ {
+   "model_type": "image-classification",
+   "num_classes": 14,
+   "class_names": [
+     "Apple_Apple_Scab",
+     "Apple_Black_Rot",
+     "Apple_Cedar_Apple_Rust",
+     "Apple_Healthy",
+     "Corn_(maize)_Cercospora_Leaf_Spot",
+     "Corn_(maize)_Common_Rust_",
+     "Corn_(maize)_Healthy",
+     "Corn_(maize)_Northern_Leaf_Blight",
+     "Tomato_Bacterial_Spot",
+     "Tomato_Early_Blight",
+     "Tomato_Healthy",
+     "Tomato_Late_Blight",
+     "Tomato_Septoria_Leaf_Spot",
+     "Tomato_Yellow_Leaf_Curl_Virus"
+   ],
+   "input_size": [
+     224,
+     224
+   ],
+   "framework": "pytorch"
+ }
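The class list and input size in config.json mirror what inference.py hard-codes. A minimal sketch of reading them back, assuming the file is saved as config.json alongside the model:

```python
import json

# Load the committed config; keys match the JSON above.
with open("config.json") as f:
    config = json.load(f)

class_names = config["class_names"]      # 14 labels across Apple, Corn, Tomato
height, width = config["input_size"]     # 224, 224
print(f"{config['num_classes']} classes, expected input {height}x{width}")
```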
inference.py ADDED
@@ -0,0 +1,87 @@
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchvision import transforms
+ from PIL import Image
+
+ class PlantDiseaseClassifier(nn.Module):
+     def __init__(self, num_classes, dropout_rate=0.3):
+         super(PlantDiseaseClassifier, self).__init__()
+
+         # EfficientNet-B2 backbone (weights=None: trained weights come from the checkpoint)
+         from torchvision import models
+         self.backbone = models.efficientnet_b2(weights=None)
+
+         # Get feature dimension
+         num_features = self.backbone.classifier[1].in_features
+
+         # Replace classifier with custom head
+         self.backbone.classifier = nn.Identity()
+
+         # Attention mechanism (squeeze-and-excitation style channel gating)
+         self.attention = nn.Sequential(
+             nn.AdaptiveAvgPool2d(1),
+             nn.Flatten(),
+             nn.Linear(num_features, num_features // 4),
+             nn.ReLU(),
+             nn.Linear(num_features // 4, num_features),
+             nn.Sigmoid()
+         )
+
+         # Custom classifier head
+         self.classifier = nn.Sequential(
+             nn.Dropout(dropout_rate),
+             nn.Linear(num_features, 512),
+             nn.BatchNorm1d(512),
+             nn.ReLU(),
+             nn.Dropout(dropout_rate * 0.5),
+             nn.Linear(512, 256),
+             nn.BatchNorm1d(256),
+             nn.ReLU(),
+             nn.Dropout(dropout_rate * 0.3),
+             nn.Linear(256, num_classes)
+         )
+
+     def forward(self, x):
+         features = self.backbone.features(x)
+         pooled = F.adaptive_avg_pool2d(features, 1)
+         pooled = torch.flatten(pooled, 1)
+         attention_weights = self.attention(features)
+         attended_features = pooled * attention_weights
+         output = self.classifier(attended_features)
+         return output
+
+ def load_model(model_path):
+     checkpoint = torch.load(model_path, map_location='cpu')
+     num_classes = len(checkpoint['class_names'])
+     model = PlantDiseaseClassifier(num_classes=num_classes)
+     model.load_state_dict(checkpoint['model_state_dict'])
+     model.eval()
+     return model, checkpoint['class_names']
+
+ def predict_image(image_path, model, class_names):
+     transform = transforms.Compose([
+         transforms.Resize((224, 224)),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+     ])
+
+     image = Image.open(image_path).convert('RGB')
+     image_tensor = transform(image).unsqueeze(0)
+
+     with torch.no_grad():
+         outputs = model(image_tensor)
+         probabilities = F.softmax(outputs, dim=1)[0]
+
+     # Get top predictions
+     top_probs, top_indices = torch.topk(probabilities, 3)
+
+     results = []
+     for i in range(len(top_indices)):
+         results.append({
+             "label": class_names[top_indices[i].item()],
+             "score": top_probs[i].item()
+         })
+
+     return results
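`load_model` expects a checkpoint dictionary with `model_state_dict` and `class_names` keys. For completeness, a minimal sketch of exporting a compatible checkpoint, assuming a trained `PlantDiseaseClassifier` instance; the training loop itself is out of scope and the output filename is illustrative:

```python
import json
import torch

from inference import PlantDiseaseClassifier

# Hypothetical export step: bundle weights and labels the way load_model expects.
with open("config.json") as f:
    class_names = json.load(f)["class_names"]

model = PlantDiseaseClassifier(num_classes=len(class_names))  # replace with a trained model
torch.save(
    {"model_state_dict": model.state_dict(), "class_names": class_names},
    "model.pth",
)
```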
model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0a1511cc75d57f3e05f1d7ef48842e52e34edddf4eeb3537795068e3e1f2ebf8
+ size 38649138