---
# Model Card for Hybrid ViT - Diabetic Retinopathy Classification
library_name: pytorch
tags:
  - classification
  - diabetic-retinopathy
  - hybrid vit
  - efficientnet
  - deep-learning
  - medical-ai
model_name: Hybrid_ViT
model_description: "A Hybrid Vision Transformer (ViT) with EfficientNetB0 as the backbone, trained to classify Diabetic Retinopathy severity into 5 categories."
dataset:
  name: "APTOS 2019 Blindness Detection"
  url: "https://www.kaggle.com/competitions/aptos2019-blindness-detection"
  description: "A dataset containing retinal fundus images for classifying Diabetic Retinopathy severity."
metrics:
  accuracy: 85.4
  loss: 0.23
  validation_accuracy: 84.5
  validation_loss: 0.25
training:
  epochs: 5
  batch_size: 32
  optimizer: Adam
  learning_rate: 1e-4
  scheduler:
    type: StepLR
    step_size: 10
    gamma: 0.5
  loss_function: CrossEntropyLoss
  k_folds: 5
  device: "CUDA"
model_architecture:
  backbone: EfficientNetB0
  transformer: Vision Transformer (ViT)
  num_classes: 5
  input_size: 224x224
framework: PyTorch
author: "Pavan Kumar Ambadapudi"
repository: "https://huggingface.co/PavanKumarAmbadapudi/DiabeticRetinopathy_Hybrid-ViT"
---
# 🏥 Diabetic Retinopathy Severity Classification
This model is a **Hybrid Vision Transformer (ViT) with EfficientNet B0** as the backbone. It is trained to classify the severity of **Diabetic Retinopathy** into five stages.
## 📌 Model Overview
- **Backbone**: EfficientNet B0 (Feature Extractor)
- **Head**: Vision Transformer (ViT) for Classification
- **Input Size**: 224×224 RGB images
- **Output Classes**:
  - 0: No Diabetic Retinopathy
  - 1: Mild
  - 2: Moderate
  - 3: Severe
  - 4: Proliferative Diabetic Retinopathy
---
## 🚀 How to Use This Model
### **1️⃣ Download the Model**
Make sure you have **PyTorch**, **Torchvision**, and **Pillow** installed (for example, `pip install torch torchvision pillow`).
Then clone the repository and navigate to it:
```bash
git lfs install  # the .pth weights are stored with Git LFS
git clone https://huggingface.co/PavanKumarAmbadapudi/DiabeticRetinopathy_Hybrid-ViT
cd DiabeticRetinopathy_Hybrid-ViT
```
Alternatively, download `Hybrid_ViT.pth` and `model.py` manually from the repository page and place them in your working directory.
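If you prefer Python over `git`, the `huggingface_hub` library can fetch individual files. This is a minimal sketch, assuming `huggingface_hub` is installed (`pip install huggingface_hub`); the helper name `download_model_files` is hypothetical. Passing `local_dir="."` saves the files in the working directory, so `from model import CNNViT` in the next step can find `model.py`.

```python
from huggingface_hub import hf_hub_download

# Repository ID taken from the clone URL above.
REPO_ID = "PavanKumarAmbadapudi/DiabeticRetinopathy_Hybrid-ViT"


def download_model_files(repo_id: str = REPO_ID, local_dir: str = ".") -> dict[str, str]:
    """Hypothetical helper: download the weights and the model definition.

    Returns a mapping from file name to the local path where it was saved.
    """
    return {
        filename: hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
        for filename in ("Hybrid_ViT.pth", "model.py")
    }
```

Calling `download_model_files()` requires network access and returns the local paths of both files.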
### **2️⃣ Load the Model in Python**
```python
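import torch
from model import CNNViT  # model.py from the repository

model = CNNViT(num_classes=5)
# map_location="cpu" lets the checkpoint load on machines without a GPU.
model.load_state_dict(torch.load("Hybrid_ViT.pth", map_location=torch.device("cpu")))
model.eval()  # switch dropout and batch-norm layers to evaluation mode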
```
### **3️⃣ Perform Inference**
To make predictions on an image:
```python
import torch
from PIL import Image
import torchvision.transforms as transforms


def map_prediction(prediction):
    """Map a class index (0-4) to its severity label."""
    mapping = {
        0: "No DR",
        1: "Mild",
        2: "Moderate",
        3: "Severe",
        4: "Proliferative DR",
    }
    return mapping.get(prediction, "Unknown")


def getTransformations(image_path):
    """Load an image and return a normalized (1, 3, 224, 224) tensor."""
    # Deterministic preprocessing for inference: random crops, flips, and
    # color jitter are training-time augmentations and would make the
    # prediction change from run to run.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # ImageNet RGB mean and std
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = Image.open(image_path).convert("RGB")
    return transform(image).unsqueeze(0)  # add a batch dimension


def predict_model_Hybrid(model, image_tensor):
    """Return the most likely severity label and its softmax probability."""
    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.softmax(outputs, dim=1)
        predicted_class = probabilities.argmax(dim=1).item()
        confidence = probabilities.max(dim=1).values.item()
    return {"label": map_prediction(predicted_class), "confidence": confidence}


# `model` is the CNNViT instance loaded in step 2.
image_path = "Path_to_Your_Image"
image_tensor = getTransformations(image_path)
print("Hybrid ViT ", predict_model_Hybrid(model, image_tensor))
```
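To classify several images at once, the single-image helper above can be generalized to a batch. This is a minimal sketch, assuming the images have already been preprocessed and stacked into one `(N, 3, 224, 224)` tensor (for example with `torch.cat`) and passed through `model`; `logits_to_predictions` and `LABELS` are hypothetical names that are not part of the repository.

```python
import torch

# Label names copied from the model card's class list (indices 0-4).
LABELS = ["No DR", "Mild", "Moderate", "Severe", "Proliferative DR"]


def logits_to_predictions(logits: torch.Tensor) -> list[dict]:
    """Convert (N, 5) logits into one {label, confidence} dict per image."""
    probs = torch.softmax(logits, dim=1)
    # max over the class dimension gives both the top probability and its index.
    confidences, indices = probs.max(dim=1)
    return [
        {"label": LABELS[i], "confidence": c}
        for i, c in zip(indices.tolist(), confidences.tolist())
    ]
```

Typical use is `logits_to_predictions(model(batch))` inside a `torch.no_grad()` block; a single forward pass over the batch is usually faster than calling the model once per image.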
## 📬 Contact
For any queries, reach out to me at:
📧 [email protected]