---
language: en
license: mit
tags:
  - pytorch
  - plant-disease
  - image-classification
  - agriculture
  - computer-vision
datasets:
  - plant-village
---

Plant Disease Classification Model

🌱 EfficientNet-B2-based model for classifying diseases of apple, tomato, and corn (maize) plants from leaf images.

Model Details

Architecture

  • Backbone: EfficientNet-B2 (pretrained)
  • Custom Head:
    • Attention mechanism
    • 3 dense layers with 512, 256, and num_classes (14) units
    • Dropout regularization (p = 0.3)

Training Data

  • Dataset: PlantVillage
  • Classes: 14 in total (4 Apple, 6 Tomato, 4 Corn); the healthy classes are included in these counts
  • Train/Val/Test Split: 80%/10%/10%
  • Image Size: 224x224
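The card does not say how the 80%/10%/10% split was drawn. One minimal, reproducible way to produce such a split is a seeded shuffle of sample indices. The helper name `split_indices`, the seed value, and the plain (non-stratified) shuffle are assumptions for illustration.

```python
import random
from typing import List, Sequence, Tuple


def split_indices(
    n: int,
    fractions: Sequence[float] = (0.8, 0.1, 0.1),
    seed: int = 42,
) -> Tuple[List[int], List[int], List[int]]:
    """Shuffle range(n) with a fixed seed and cut it into train/val/test parts."""
    if abs(sum(fractions) - 1.0) > 1e-9:
        raise ValueError("fractions must sum to 1")
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # same seed -> same split every run
    n_train = int(n * fractions[0])
    n_val = int(n * fractions[1])
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    # The test part takes whatever remains, so rounding never drops a sample.
    test = indices[n_train + n_val:]
    return train, val, test


train, val, test = split_indices(1000)
print(len(train), len(val), len(test))  # 800 100 100
```

A stratified variant (shuffling and cutting each class separately) would keep class proportions equal across the three parts, which may matter if some of the 14 classes are much smaller than others.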

Performance Metrics

| Metric | Value |
|---|---|
| Train Accuracy | 98.66% |
| Validation Accuracy | 99.24% |
| Test Accuracy | 98.91% |
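Accuracy here presumably means top-1 accuracy: the fraction of images whose highest-scoring class matches the label. A minimal sketch of computing it over a held-out split, assuming a `DataLoader` that yields `(images, labels)` batches and a model that returns logits either directly or via a `.logits` attribute (as Hugging Face models do). The function name `evaluate_accuracy` is hypothetical. Calling `model.eval()` matters for this architecture because the dropout layers (p = 0.3) are active only in training mode.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def evaluate_accuracy(model: nn.Module, loader: DataLoader) -> float:
    """Top-1 accuracy: fraction of samples whose argmax logit equals the label."""
    model.eval()  # disable dropout so the score is deterministic
    correct, total = 0, 0
    for images, labels in loader:
        outputs = model(images)
        # Hugging Face models wrap logits in an output object; plain modules return a tensor.
        logits = outputs.logits if hasattr(outputs, "logits") else outputs
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.numel()
    return correct / total


# Toy check: an identity "model" applied to hand-written logits.
toy_logits = torch.tensor([[2.0, 1.0], [0.0, 3.0], [5.0, 0.0], [0.0, 1.0]])
toy_labels = torch.tensor([0, 1, 1, 1])
loader = DataLoader(TensorDataset(toy_logits, toy_labels), batch_size=2)
acc = evaluate_accuracy(nn.Identity(), loader)
print(acc)  # 0.75 (three of four argmax predictions match)
```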

How to Use

Inference

```python
from transformers import AutoModelForImageClassification
from PIL import Image
import torch
import torchvision.transforms as transforms

# Load the model and switch to inference mode (disables dropout)
model = AutoModelForImageClassification.from_pretrained("Abuzaid01/plant-disease-classifier")
model.eval()

# Preprocess: resize, center-crop to the 224x224 training resolution,
# then normalize with the ImageNet mean/std
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the image as 3-channel RGB (PNGs may be RGBA or grayscale) and add a batch dimension
image = Image.open("plant.jpg").convert("RGB")
inputs = transform(image).unsqueeze(0)  # shape: (1, 3, 224, 224)

# Predict
with torch.no_grad():
    outputs = model(inputs)
    prediction = torch.argmax(outputs.logits, dim=1).item()

print(f"Predicted class: {model.config.id2label[prediction]}")
```
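The inference snippet reports only the single most likely class. To also see how confident the model is, the logits can be converted to probabilities with a softmax and the top k classes kept. A sketch, assuming `logits` has shape (1, num_classes) like `outputs.logits` above and that `id2label` maps integer ids to class names as `model.config.id2label` does; the helper `top_k_predictions` and the example labels are hypothetical.

```python
import torch


def top_k_predictions(logits: torch.Tensor, id2label: dict, k: int = 3) -> list:
    """Return the k most likely (label, probability) pairs for a single image."""
    probs = torch.softmax(logits, dim=-1).squeeze(0)  # (1, C) -> (C,)
    values, indices = torch.topk(probs, k=min(k, probs.numel()))
    return [(id2label[i], p) for i, p in zip(indices.tolist(), values.tolist())]


# Hypothetical labels and logits standing in for model.config.id2label and outputs.logits.
example_labels = {0: "label_a", 1: "label_b", 2: "label_c"}
example_logits = torch.tensor([[0.5, 2.0, -1.0]])
for label, prob in top_k_predictions(example_logits, example_labels, k=2):
    print(f"{label}: {prob:.1%}")
```

With the real model, pass `outputs.logits` and `model.config.id2label`; a low top-1 probability can flag images (for example, crops outside apple, tomato, and corn) that the classifier was never trained on.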