import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from PIL import Image

# Preprocessing applied to every image before inference. Built once at import
# time instead of on every call. The mean/std values are the standard ImageNet
# channel statistics.
# NOTE(review): assumes training used the same 224x224 resize and
# normalisation -- confirm against the training pipeline.
_INFERENCE_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


class PlantDiseaseClassifier(nn.Module):
    """EfficientNet-B2 feature extractor with a channel gate and an MLP head.

    The backbone's own classifier is replaced by ``nn.Identity``; only
    ``backbone.features`` is used in ``forward``. The pooled feature vector is
    multiplied element-wise by a learned sigmoid gate (``self.attention``)
    before being passed to ``self.classifier``.

    Args:
        num_classes: Size of the final linear layer's output.
        dropout_rate: Base dropout probability. The three dropout layers of
            the head use ``dropout_rate``, ``dropout_rate * 0.5`` and
            ``dropout_rate * 0.3`` respectively.
    """

    def __init__(self, num_classes, dropout_rate=0.3):
        super().__init__()

        # Calling the constructor without arguments yields randomly
        # initialised weights on both the legacy (``pretrained=False``) and
        # the current (``weights=None``) torchvision APIs, and avoids the
        # deprecation warning that ``pretrained=`` triggers on newer releases.
        self.backbone = models.efficientnet_b2()

        # Width of the backbone's final feature map, read from the stock
        # classifier before it is discarded.
        num_features = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Identity()

        # Per-channel gate: pool -> bottleneck MLP -> sigmoid in (0, 1).
        # Layer order is part of the state_dict key layout; do not reorder.
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(num_features, num_features // 4),
            nn.ReLU(),
            nn.Linear(num_features // 4, num_features),
            nn.Sigmoid(),
        )

        # Classification head. Layer order is part of the state_dict key
        # layout; do not reorder.
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(num_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout_rate * 0.5),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout_rate * 0.3),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of images.

        Args:
            x: Image batch fed to ``backbone.features``
                (presumably ``(batch, 3, H, W)`` -- confirm with callers).

        Returns:
            Tensor of shape ``(batch, num_classes)`` with unnormalised logits.
        """
        features = self.backbone.features(x)
        pooled = torch.flatten(F.adaptive_avg_pool2d(features, 1), 1)
        attention_weights = self.attention(features)
        return self.classifier(pooled * attention_weights)


def load_model(model_path):
    """Load a ``PlantDiseaseClassifier`` checkpoint onto the CPU in eval mode.

    The checkpoint must be a mapping with the keys ``'class_names'`` (one
    entry per output class) and ``'model_state_dict'``.

    Args:
        model_path: Path or file-like object accepted by ``torch.load``.

    Returns:
        Tuple ``(model, class_names)``.

    Raises:
        KeyError: If a required key is missing from the checkpoint.
    """
    # SECURITY: torch.load deserialises with pickle, which can execute
    # arbitrary code. Only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location='cpu')
    class_names = checkpoint['class_names']

    model = PlantDiseaseClassifier(num_classes=len(class_names))
    model.load_state_dict(checkpoint['model_state_dict'])
    # Eval mode is required for single-image inference: BatchNorm1d rejects a
    # batch of one sample in training mode, and dropout must be disabled.
    model.eval()
    return model, class_names


def predict_image(image_path, model, class_names, top_k=3):
    """Classify one image and return the most probable labels.

    The model is expected to already be in eval mode (``load_model`` does
    this); this function does not change the model's mode.

    Args:
        image_path: Path or file object accepted by ``PIL.Image.open``.
        model: Callable mapping a ``(1, 3, 224, 224)`` float tensor to logits
            of shape ``(1, num_classes)``.
        class_names: Sequence mapping class index to label.
        top_k: Maximum number of predictions to return. Clamped to the number
            of classes so models with fewer than ``top_k`` classes work.

    Returns:
        List of ``{"label": str, "score": float}`` dicts sorted by descending
        probability.

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    # Image.open is lazy and keeps the file open; convert() forces the pixel
    # data to load so the handle can be released immediately.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')

    image_tensor = _INFERENCE_TRANSFORM(image).unsqueeze(0)

    # Follow the model to whatever device the caller moved it to.
    if isinstance(model, nn.Module):
        first_param = next(model.parameters(), None)
        if first_param is not None:
            image_tensor = image_tensor.to(first_param.device)

    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = F.softmax(outputs, dim=1)[0]

    # torch.topk raises if k exceeds the number of classes.
    k = min(top_k, probabilities.numel())
    top_probs, top_indices = torch.topk(probabilities, k)

    return [
        {"label": class_names[index], "score": score}
        for score, index in zip(top_probs.tolist(), top_indices.tolist())
    ]