---
# Model Card for Hybrid ViT - Diabetic Retinopathy Classification

library_name: pytorch
tags:
  - classification
  - diabetic-retinopathy
  - hybrid-vit
  - efficientnet
  - deep-learning
  - medical-ai
model_name: Hybrid_ViT
model_description: "A Hybrid Vision Transformer (ViT) with EfficientNetB0 as the backbone, trained to classify Diabetic Retinopathy severity into 5 categories."
dataset:
  name: "APTOS 2019 Blindness Detection"
  url: "https://www.kaggle.com/competitions/aptos2019-blindness-detection"
  description: "A dataset containing retinal fundus images for classifying Diabetic Retinopathy severity."

metrics:
  accuracy: 85.4
  loss: 0.23
  validation_accuracy: 84.5
  validation_loss: 0.25

training:
  epochs: 5
  batch_size: 32
  optimizer: Adam
  learning_rate: 1e-4
  scheduler: 
    type: StepLR
    step_size: 10
    gamma: 0.5
  loss_function: CrossEntropyLoss
  k_folds: 5
  device: "CUDA"
model_architecture:
  backbone: EfficientNetB0
  transformer: Vision Transformer (ViT)
  num_classes: 5
  input_size: 224x224
  framework: PyTorch


author: "Pavan Kumar Ambadapudi"
repository: "https://huggingface.co/PavanKumarAmbadapudi/DiabeticRetinopathy_Hybrid-ViT"

---


# 🏥 Diabetic Retinopathy Severity Classification

This model is a **Hybrid Vision Transformer (ViT)** that uses **EfficientNet B0** as its backbone. It is trained on retinal fundus images to classify the severity of **Diabetic Retinopathy** into five stages, from 0 (no disease) to 4 (proliferative).

## 📌 Model Overview
- **Backbone**: EfficientNet B0 (Feature Extractor)
- **Head**: Vision Transformer (ViT) for Classification
- **Input Size**: 224x224 (RGB Images)
- **Output Classes**:
  - 0: No Diabetic Retinopathy
  - 1: Mild
  - 2: Moderate
  - 3: Severe
  - 4: Proliferative Diabetic Retinopathy
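
To make the backbone/head split concrete, the sketch below works out how many tokens the transformer head would receive. It assumes the ViT consumes the final EfficientNet B0 feature map, which in the standard architecture has an overall stride of 32 and 1280 channels. The repository's `model.py` may tap a different layer or project the channels, so `token_grid` and its defaults are illustrative only.

```python
import math


def token_grid(input_size=224, stride=32, channels=1280):
    """Return (num_tokens, token_dim) when each spatial cell of the CNN
    feature map is treated as one transformer token.

    Assumption: the final EfficientNet B0 feature map is used
    (overall stride 32, 1280 channels); model.py may differ.
    """
    side = math.ceil(input_size / stride)  # 224 / 32 = 7
    return side * side, channels           # 7 * 7 = 49 tokens


num_tokens, token_dim = token_grid()
print(num_tokens, token_dim)  # 49 1280
```

Under these assumptions a 224x224 image becomes a short sequence of 49 tokens, which keeps self-attention inexpensive compared with patching the raw image.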

---

## 🚀 How to Use This Model

### **1️⃣  Download the Model**
Make sure you have **PyTorch** and **Torchvision** installed, then clone the repository and navigate to it:
```bash
git clone https://huggingface.co/PavanKumarAmbadapudi/DiabeticRetinopathy_Hybrid-ViT
cd DiabeticRetinopathy_Hybrid-ViT
```
Or manually download these two files:

- `Hybrid_ViT.pth`
- `model.py`
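
Before loading the model, it can help to confirm that both files are where Python will look for them. The snippet below is a minimal sketch using only the standard library; `missing_files` is a hypothetical helper, not part of the repository, and it assumes the files sit in the current working directory.

```python
from pathlib import Path

# The two files named in this section; the folder argument is an assumption.
REQUIRED_FILES = ("Hybrid_ViT.pth", "model.py")


def missing_files(folder="."):
    """Return the required files that are not present in `folder`."""
    root = Path(folder)
    return [name for name in REQUIRED_FILES if not (root / name).is_file()]


missing = missing_files(".")
if missing:
    print("Missing files:", ", ".join(missing))
else:
    print("All required files found.")
```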

### **2️⃣ Load the Model in Python** 
```python
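import torch
from model import CNNViT  # model.py from this repository

model = CNNViT(num_classes=5)
# Load the trained weights onto the CPU; this works with or without a GPU
model.load_state_dict(torch.load("Hybrid_ViT.pth", map_location=torch.device('cpu')))
model.eval()  # switch to inference mode (disables dropout, etc.)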
```

### **3️⃣ Perform Inference** 

To make predictions on an image:
```python
from PIL import Image
import torchvision.transforms as transforms

def map_prediction(prediction):
    mapping = {
        0: "No DR",
        1: "Mild",
        2: "Moderate",
        3: "Severe",
        4: "Proliferative DR"
    }
    return mapping.get(prediction, "Unknown")

image_path = 'Path_to_Your_Image'  

def getTransformations(image_path):
    # Deterministic preprocessing for inference. Random augmentations
    # (random crop, flip, color jitter) belong in training only and would
    # make predictions vary between runs on the same image.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet RGB mean and std
    ])
    image = Image.open(image_path).convert("RGB")
    return transform(image).unsqueeze(0)  # add a batch dimension

image_tensor = getTransformations(image_path)

def predict_model_Hybrid(model, image_tensor):
    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.softmax(outputs, dim=1)
        predicted_classes = probabilities.argmax(dim=1).item()
        confidences = probabilities.max(dim=1).values.item()

    model_predictions = {"label": map_prediction(predicted_classes), "confidence": confidences}
    return model_predictions

print("Hybrid ViT ", predict_model_Hybrid(model, image_tensor))
```
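
The label and confidence returned above come from a softmax over the five output logits followed by an argmax. The framework-free sketch below reproduces that step in plain Python. The logits are made up for illustration and are not real model output, and `softmax` and `top_prediction` are hypothetical helper names.

```python
import math

LABELS = ["No DR", "Mild", "Moderate", "Severe", "Proliferative DR"]


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)  # subtract the max to avoid overflow in exp
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_prediction(logits):
    """Return the most likely label and its softmax probability."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return {"label": LABELS[best], "confidence": probs[best]}


# Hypothetical logits for illustration only, not real model output.
example_logits = [0.1, 0.3, 2.0, 0.5, -1.0]
result = top_prediction(example_logits)
print(result["label"], round(result["confidence"], 3))
```

Because softmax outputs sum to 1, the reported confidence is relative to the other four classes rather than an absolute measure of certainty.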

## 📬 Contact
For any queries, reach out to me at:
📧 [email protected]