---
language: en
license: mit
tags:
- pytorch
- plant-disease
- image-classification
- agriculture
- computer-vision
datasets:
- plant-village
---

# Plant Disease Classification Model 🌱

An EfficientNet-B2-based model for classifying diseases of apple, tomato, and corn (maize) plants.

## Model Details

### Architecture

- **Backbone**: EfficientNet-B2 (pretrained)
- **Custom Head**:
  - Attention mechanism
  - 3 dense layers (512, 256, and `num_classes` units)
  - Dropout regularization (rate 0.3)

### Training Data

- **Dataset**: PlantVillage
- **Classes**: 14 in total: 4 Apple, 6 Tomato, and 4 Corn classes (disease and healthy classes combined)
- **Train/Val/Test Split**: 80% / 10% / 10%
- **Image Size**: 224 × 224 pixels

### Performance Metrics

| Metric         | Value  |
|----------------|--------|
| Train Accuracy | 98.66% |
| Val Accuracy   | 99.24% |
| Test Accuracy  | 98.91% |

## How to Use

### Inference

```python
from transformers import AutoModelForImageClassification
from PIL import Image
import torch
import torchvision.transforms as transforms

# Load the model and make sure it is in evaluation mode (dropout disabled)
model = AutoModelForImageClassification.from_pretrained("Abuzaid01/plant-disease-classifier")
model.eval()

# Preprocessing: resize, center-crop to 224x224, normalize with ImageNet mean/std
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the image as RGB (handles RGBA/grayscale files) and add a batch dimension
image = Image.open("plant.jpg").convert("RGB")
inputs = transform(image).unsqueeze(0)

# Predict
with torch.no_grad():
    outputs = model(inputs)
    prediction = torch.argmax(outputs.logits, dim=1).item()

print(f"Predicted class: {model.config.id2label[prediction]}")
```
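The inference example reports only the single most likely class. When you also want confidence scores, a small helper along these lines can turn the logits into softmax probabilities and return the k most likely labels. `top_k_predictions` is a hypothetical helper, not part of the model or of `transformers`; it assumes `id2label` maps integer class indices to label strings, as `model.config.id2label` does in the inference example.

```python
import torch


def top_k_predictions(logits: torch.Tensor, id2label: dict, k: int = 3):
    """Return the k most likely (label, probability) pairs for one image.

    `logits` is assumed to have shape (1, num_classes), i.e. a batch of one image.
    """
    probs = torch.softmax(logits, dim=1)[0]        # class probabilities for the image
    k = min(k, probs.numel())                      # guard against k > num_classes
    top_probs, top_idx = torch.topk(probs, k)      # sorted from most to least likely
    return [(id2label[i], p) for i, p in zip(top_idx.tolist(), top_probs.tolist())]


# Usage with `outputs` and `model` from the inference example:
# for label, p in top_k_predictions(outputs.logits, model.config.id2label):
#     print(f"{label}: {p:.2%}")
```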
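The 80% / 10% / 10% split from Training Data can be reproduced with `torch.utils.data.random_split`, assuming a simple random (non-stratified) split; the card does not say how the original split was drawn. `split_80_10_10` and the seed value are hypothetical choices for illustration.

```python
import torch
from torch.utils.data import Dataset, random_split


def split_80_10_10(dataset: Dataset, seed: int = 42):
    """Randomly split a dataset into 80% train, 10% val, 10% test subsets."""
    n = len(dataset)
    n_train = n * 8 // 10
    n_val = n // 10
    n_test = n - n_train - n_val  # remainder goes to test so the lengths sum to n
    generator = torch.Generator().manual_seed(seed)  # fixed seed -> reproducible split
    return random_split(dataset, [n_train, n_val, n_test], generator=generator)


# Hypothetical usage with an ImageFolder copy of PlantVillage (path is illustrative):
# from torchvision.datasets import ImageFolder
# train_set, val_set, test_set = split_80_10_10(ImageFolder("PlantVillage/"))
```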
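The accuracies in the Performance Metrics table are top-1 accuracies: the fraction of images whose highest-scoring class matches the label. A minimal evaluation loop for computing that figure on a held-out loader might look like the following; `evaluate_accuracy` is a hypothetical helper, and the `getattr` line lets it accept both Hugging Face models (which return an output object with `.logits`) and plain modules that return a tensor.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader


@torch.no_grad()
def evaluate_accuracy(model: nn.Module, loader: DataLoader, device: str = "cpu") -> float:
    """Top-1 accuracy of `model` over every (images, labels) batch in `loader`."""
    model.eval()
    model.to(device)
    correct, total = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        logits = getattr(outputs, "logits", outputs)  # HF output object or raw tensor
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    return correct / total if total else 0.0
```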