---
library_name: pytorch
tags:
- classification
- diabetic-retinopathy
- hybrid vit
- efficientnet
- deep-learning
- medical-ai
model_name: Hybrid_ViT
model_description: >-
  A Hybrid Vision Transformer (ViT) with EfficientNetB0 as the backbone, trained
  to classify Diabetic Retinopathy severity into 5 categories.
dataset:
  name: APTOS 2019 Blindness Detection
  url: https://www.kaggle.com/competitions/aptos2019-blindness-detection
  description: >-
    A dataset containing retinal fundus images for classifying Diabetic
    Retinopathy severity.
metrics:
  accuracy: 85.4
  loss: 0.23
  validation_accuracy: 84.5
  validation_loss: 0.25
training:
  epochs: 5
  batch_size: 32
  optimizer: Adam
  learning_rate: 0.0001
  scheduler:
    type: StepLR
    step_size: 10
    gamma: 0.5
  loss_function: CrossEntropyLoss
  k_folds: 5
  device: CUDA
model_architecture:
  backbone: EfficientNetB0
  transformer: Vision Transformer (ViT)
  num_classes: 5
  input_size: 224x224
framework: PyTorch
author: Pavan Kumar Ambadapudi
repository: https://huggingface.co/PavanKumarAmbadapudi/DiabeticRetinopathy_Hybrid-ViT
---
# 🏥 Diabetic Retinopathy Severity Classification
This model is a Hybrid Vision Transformer (ViT) that uses EfficientNet B0 as its backbone. It is trained on the APTOS 2019 Blindness Detection dataset to classify the severity of Diabetic Retinopathy into five stages.
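The metadata at the top of this card lists the training setup: Adam with a learning rate of 0.0001, a StepLR scheduler (step_size 10, gamma 0.5), CrossEntropyLoss, batch size 32, 5 epochs, and 5-fold cross-validation. A minimal sketch of how those pieces fit together in PyTorch follows. It is not the author's training script: `build_training_components`, `kfold_indices`, and `train_one_epoch` are hypothetical helpers, and the fold split is plain contiguous k-fold with no shuffling or stratification.

```python
import torch
import torch.nn as nn

# Values taken from the `training` block of the model card metadata
LEARNING_RATE = 1e-4
STEP_SIZE = 10
GAMMA = 0.5
K_FOLDS = 5


def build_training_components(model: nn.Module):
    """Hypothetical helper: loss, optimizer, and LR scheduler as listed in the metadata."""
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)
    return criterion, optimizer, scheduler


def kfold_indices(n_samples: int, k: int = K_FOLDS):
    """Hypothetical helper: yield (train_idx, val_idx) for k contiguous, non-overlapping folds."""
    indices = list(range(n_samples))
    # Spread the remainder over the first folds so sizes differ by at most one
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        val_idx = indices[start:start + size]
        train_idx = indices[:start] + indices[start + size:]
        yield train_idx, val_idx
        start += size


def train_one_epoch(model, loader, criterion, optimizer, scheduler, device="cpu"):
    """Hypothetical helper: one pass over `loader`, then one scheduler step."""
    model.train()
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
    scheduler.step()  # StepLR is stepped once per epoch in this sketch
```

If the scheduler is stepped once per epoch as here, `step_size: 10` combined with `epochs: 5` means the learning rate never decays within a run; the listed values only produce a decay if training ran longer or stepped the scheduler more often.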
## Model Overview
- Backbone: EfficientNet B0 (feature extractor)
- Head: Vision Transformer (ViT) for classification
- Input size: 224x224 RGB images
- Output classes:
  - 0: No Diabetic Retinopathy
  - 1: Mild
  - 2: Moderate
  - 3: Severe
  - 4: Proliferative Diabetic Retinopathy
## How to Use This Model
### 1️⃣ Download the Model
Make sure you have PyTorch and Torchvision installed, then clone the repository and navigate to it:

```bash
pip install torch torchvision
# Large files such as Hybrid_ViT.pth are stored with Git LFS
git lfs install
git clone https://huggingface.co/PavanKumarAmbadapudi/DiabeticRetinopathy_Hybrid-ViT
cd DiabeticRetinopathy_Hybrid-ViT
```

Or download the files `Hybrid_ViT.pth` and `model.py` manually.
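As an alternative to cloning, the two files can be fetched from Python with the `huggingface_hub` package (assumed to be installed with `pip install huggingface_hub`). `hf_hub_download` is part of that library; the wrapper `download_model_files` is a hypothetical helper. It also puts the download folder on `sys.path` so that `from model import CNNViT` resolves to the downloaded `model.py`.

```python
import os
import sys

REPO_ID = "PavanKumarAmbadapudi/DiabeticRetinopathy_Hybrid-ViT"


def download_model_files(repo_id: str = REPO_ID) -> str:
    """Hypothetical helper: download model.py and Hybrid_ViT.pth, return the weights path."""
    # Imported lazily so this module still loads if huggingface_hub is not installed
    from huggingface_hub import hf_hub_download

    # Both files land in the same cached snapshot directory of the repository
    hf_hub_download(repo_id=repo_id, filename="model.py")
    weights_path = hf_hub_download(repo_id=repo_id, filename="Hybrid_ViT.pth")

    # Make `from model import CNNViT` find the downloaded model.py
    snapshot_dir = os.path.dirname(weights_path)
    if snapshot_dir not in sys.path:
        sys.path.insert(0, snapshot_dir)
    return weights_path
```

Call `weights_path = download_model_files()` before step 2 and pass `weights_path` to `torch.load` in place of `"Hybrid_ViT.pth"`.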
### 2️⃣ Load the Model in Python
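```python
import torch
from model import CNNViT  # model.py from this repository

# Build the architecture, then load the trained weights onto the CPU
model = CNNViT(num_classes=5)
model.load_state_dict(torch.load("Hybrid_ViT.pth", map_location=torch.device("cpu")))
model.eval()  # switch to inference mode (disables dropout, freezes batch-norm statistics)
```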
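The snippet above always loads the weights onto the CPU. If a GPU is available, a small variation moves the model there. This is a sketch: `load_for_inference` is a hypothetical helper that works with any `nn.Module`, including `CNNViT`, and `weights_only=True` (available in `torch.load` since PyTorch 1.13) restricts unpickling to tensors and plain containers, which is all a state dict needs.

```python
import torch
import torch.nn as nn


def load_for_inference(model: nn.Module, weights_path: str, device=None) -> nn.Module:
    """Hypothetical helper: load a saved state dict and prepare the model for inference."""
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # map_location remaps tensors that were saved on a GPU onto the chosen device
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    return model.to(device).eval()
```

For example, `model = load_for_inference(CNNViT(num_classes=5), "Hybrid_ViT.pth")`. Input tensors must then be moved to the same device before the forward pass, e.g. `image_tensor.to(next(model.parameters()).device)`.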
### 3️⃣ Perform Inference
To make predictions on an image:
```python
import torch
import torchvision.transforms as transforms
from PIL import Image


def map_prediction(prediction):
    """Map a class index to a human-readable severity label."""
    mapping = {
        0: "No DR",
        1: "Mild",
        2: "Moderate",
        3: "Severe",
        4: "Proliferative DR",
    }
    return mapping.get(prediction, "Unknown")


image_path = "Path_to_Your_Image"


def getTransformations(image_path):
    """Load an image and preprocess it into a (1, 3, 224, 224) tensor."""
    # Deterministic preprocessing for inference; random augmentations such as
    # RandomResizedCrop, RandomHorizontalFlip and ColorJitter belong in training only
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # ImageNet RGB mean and std
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = Image.open(image_path).convert("RGB")
    return transform(image).unsqueeze(0)  # add a batch dimension


image_tensor = getTransformations(image_path)


def predict_model_Hybrid(model, image_tensor):
    """Return the predicted severity label and its softmax confidence."""
    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.softmax(outputs, dim=1)
        predicted_class = probabilities.argmax(dim=1).item()
        confidence = probabilities.max(dim=1).values.item()
    return {"label": map_prediction(predicted_class), "confidence": confidence}


# `model` is the network loaded in step 2
print("Hybrid ViT ", predict_model_Hybrid(model, image_tensor))
```
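`predict_model_Hybrid` reports only the top class for a single image. To inspect the full probability distribution over all five grades, or to score several preprocessed images at once, the same softmax logic extends to a batch. A sketch, where `predict_batch` and `SEVERITY_LABELS` are hypothetical names and the input is a `(N, 3, 224, 224)` tensor, for example several outputs of `getTransformations` joined with `torch.cat`:

```python
import torch
import torch.nn as nn

# Label order follows the class indices 0-4 listed in the model overview
SEVERITY_LABELS = ["No DR", "Mild", "Moderate", "Severe", "Proliferative DR"]


def predict_batch(model: nn.Module, batch: torch.Tensor):
    """Hypothetical helper: per-image label, confidence, and all class probabilities."""
    model.eval()
    device = next(model.parameters()).device  # run on whatever device holds the model
    with torch.no_grad():
        probs = torch.softmax(model(batch.to(device)), dim=1).cpu()
    results = []
    for row in probs:
        idx = int(row.argmax())
        results.append({
            "label": SEVERITY_LABELS[idx],
            "confidence": float(row[idx]),
            "probabilities": {name: float(p) for name, p in zip(SEVERITY_LABELS, row)},
        })
    return results
```

Looking at the whole distribution, rather than only the top label, shows when the model is split between adjacent grades such as Moderate and Severe.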
## 📬 Contact
For any queries, reach out to me at: 📧 [email protected]