|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torchvision import transforms |
|
from PIL import Image |
|
|
|
class PlantDiseaseClassifier(nn.Module):
    """EfficientNet-B2 feature extractor with channel gating and an MLP head.

    The network has three parts:

    * ``backbone`` - a torchvision EfficientNet-B2 whose own classifier is
      replaced by ``nn.Identity``; only ``backbone.features`` is used in
      ``forward``.
    * ``attention`` - pools the feature map and passes it through a
      bottleneck MLP (``num_features -> num_features // 4 -> num_features``)
      ending in a sigmoid, giving one gate in (0, 1) per channel.
    * ``classifier`` - a three-layer MLP (``num_features -> 512 -> 256 ->
      num_classes``) with batch norm, ReLU and progressively lighter dropout.

    The order of layers inside both ``nn.Sequential`` containers determines
    the ``state_dict`` key names, so it must not be changed if existing
    checkpoints are to keep loading.

    Args:
        num_classes: Size of the final linear layer (number of output logits).
        dropout_rate: Dropout probability of the first head layer; the second
            and third dropout layers use 0.5x and 0.3x this value.
    """

    def __init__(self, num_classes, dropout_rate=0.3):
        super().__init__()

        # Imported at construction time, so torchvision.models is only
        # required when a model is actually built.
        from torchvision import models

        # Randomly initialised backbone. The ``pretrained`` keyword is
        # deprecated since torchvision 0.13 (replaced by ``weights``); calling
        # with no argument yields "no pretrained weights" under both APIs.
        self.backbone = models.efficientnet_b2()

        # Read the backbone's feature width from its original linear layer
        # before that layer is discarded.
        num_features = self.backbone.classifier[1].in_features

        # Drop the stock classification head; this module supplies its own.
        self.backbone.classifier = nn.Identity()

        # Per-channel gate computed from the globally pooled feature map.
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(num_features, num_features // 4),
            nn.ReLU(),
            nn.Linear(num_features // 4, num_features),
            nn.Sigmoid(),
        )

        # Classification head producing raw logits (no softmax here).
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(num_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout_rate * 0.5),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout_rate * 0.3),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of images.

        Args:
            x: Input batch fed to ``backbone.features`` (presumably
                ``(batch, 3, H, W)`` image tensors - confirm with callers).

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding unnormalised
            logits.
        """
        features = self.backbone.features(x)

        # Global average pool -> one descriptor value per channel.
        pooled = torch.flatten(F.adaptive_avg_pool2d(features, 1), 1)

        # ``self.attention`` starts with its own pooling layer, so it is fed
        # the unpooled feature map rather than ``pooled``.
        attention_weights = self.attention(features)

        # Re-weight each channel by its gate before classification.
        attended_features = pooled * attention_weights
        return self.classifier(attended_features)
|
|
|
def load_model(model_path):
    """Load a saved ``PlantDiseaseClassifier`` checkpoint for inference.

    The checkpoint must be a mapping with two entries:

    * ``'class_names'`` - sequence of labels; its length sets the number of
      output classes of the rebuilt network.
    * ``'model_state_dict'`` - the weights passed to ``load_state_dict``.

    Args:
        model_path: Location of the checkpoint, in any form ``torch.load``
            accepts.

    Returns:
        A ``(model, class_names)`` tuple. The model lives on the CPU and is
        switched to evaluation mode.

    Raises:
        KeyError: If either required entry is absent from the checkpoint.
    """
    # SECURITY NOTE: torch.load unpickles the file, which can execute
    # arbitrary code unless ``weights_only=True`` is in effect. Only load
    # checkpoints from trusted sources, or pass ``weights_only=True`` once
    # the minimum supported PyTorch version allows it.
    checkpoint = torch.load(model_path, map_location='cpu')

    # Fail early with a descriptive message instead of building the whole
    # network first and then hitting a bare KeyError.
    missing = [
        key for key in ('class_names', 'model_state_dict')
        if key not in checkpoint
    ]
    if missing:
        raise KeyError(
            f"checkpoint is missing required key(s): {', '.join(missing)}"
        )

    class_names = checkpoint['class_names']
    model = PlantDiseaseClassifier(num_classes=len(class_names))
    model.load_state_dict(checkpoint['model_state_dict'])

    # Evaluation mode: disables dropout and makes batch norm use its
    # running statistics.
    model.eval()
    return model, class_names
|
|
|
def predict_image(image_path, model, class_names, top_k=3):
    """Classify one image file and return the most probable labels.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        model: Callable mapping a ``(1, 3, 224, 224)`` float tensor to a
            ``(1, num_classes)`` tensor of logits. It is called as-is, so it
            should already be in eval mode and on the CPU.
        class_names: Sequence mapping a class index to its label.
        top_k: Maximum number of predictions to return (default 3). Fewer are
            returned when the model has fewer than ``top_k`` classes.

    Returns:
        A list of ``{"label": str, "score": float}`` dicts ordered from most
        to least probable, where ``score`` is the softmax probability.

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        raise ValueError("top_k must be at least 1")

    # Resize to 224x224 and normalise with the standard ImageNet channel
    # statistics (presumably matching the training pipeline - confirm).
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # ``convert`` decodes the pixels into a new image, so the underlying file
    # can be closed as soon as the ``with`` block ends.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')

    # Add the batch dimension expected by the model.
    image_tensor = transform(image).unsqueeze(0)

    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = F.softmax(outputs, dim=1)[0]

    # torch.topk raises if k exceeds the number of classes, so clamp it.
    k = min(top_k, probabilities.numel())
    top_probs, top_indices = torch.topk(probabilities, k)

    return [
        {"label": class_names[index], "score": score}
        for score, index in zip(top_probs.tolist(), top_indices.tolist())
    ]
|
|