PavanKumarAmbadapudi committed on
Commit 27dfecc · 1 Parent(s): afd64a3

Successfully Added

Files changed (3)
  1. Hybrid_ViT.pth +3 -0
  2. README.md +103 -0
  3. model.py +64 -0
Hybrid_ViT.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:00d90af42c936718b9339fdd5024a812ab6889f5e75f3f99a807497ea1f9d84c
+ size 77229102
README.md ADDED
@@ -0,0 +1,103 @@
+ # 🏥 Diabetic Retinopathy Severity Classification
+
+ This model is a **Hybrid Vision Transformer (ViT)** with an **EfficientNet-B0** backbone, trained to classify the severity of **Diabetic Retinopathy** into five stages.
+
+ ## 📌 Model Overview
+ - **Backbone**: EfficientNet-B0 (feature extractor)
+ - **Head**: Vision Transformer (ViT) block for classification
+ - **Input Size**: 224x224 (RGB images)
+ - **Output Classes**:
+   - 0: No Diabetic Retinopathy
+   - 1: Mild
+   - 2: Moderate
+   - 3: Severe
+   - 4: Proliferative Diabetic Retinopathy
+
+ ---
+
+ ## 🚀 How to Use This Model
+
+ ### **1️⃣ Download the Model**
+ Make sure you have **PyTorch** and **Torchvision** installed, then clone the repository and navigate into it:
+ ```bash
+ git clone https://huggingface.co/PavanKumarAmbadapudi/DiabeticRetinopathy_Hybrid-ViT
+ cd DiabeticRetinopathy_Hybrid-ViT
+ ```
+ Or manually download the files: `Hybrid_ViT.pth` and `model.py`.
+
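+ Alternatively, the files can be fetched programmatically — a minimal sketch, assuming the `huggingface_hub` package is installed (`pip install huggingface_hub`):
+ ```python
+ from huggingface_hub import hf_hub_download
+
+ # Downloads each file into the local Hugging Face cache and returns its path
+ weights_path = hf_hub_download(
+     repo_id="PavanKumarAmbadapudi/DiabeticRetinopathy_Hybrid-ViT",
+     filename="Hybrid_ViT.pth",
+ )
+ model_def_path = hf_hub_download(
+     repo_id="PavanKumarAmbadapudi/DiabeticRetinopathy_Hybrid-ViT",
+     filename="model.py",
+ )
+ ```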
+
+ ### **2️⃣ Load the Model in Python**
+ ```python
+ import torch
+ from model import CNNViT
+
+ model = CNNViT(num_classes=5)
+ model.load_state_dict(torch.load("Hybrid_ViT.pth", map_location=torch.device('cpu')))
+ model.eval()
+ ```
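+ If a GPU is available, the model can optionally be moved to it (input tensors must then be moved to the same device before calling the model):
+ ```python
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.to(device)
+ ```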
+
+ ### **3️⃣ Perform Inference**
+
+ To make predictions on an image:
+ ```python
+ from PIL import Image
+ import torchvision.transforms as transforms
+
+ def map_prediction(prediction):
+     mapping = {
+         0: "No DR",
+         1: "Mild",
+         2: "Moderate",
+         3: "Severe",
+         4: "Proliferative DR"
+     }
+     return mapping.get(prediction, "Unknown")
+
+ image_path = 'Path_to_Your_Image'
+
+ def getTransformations(image_path):
+     # Deterministic preprocessing for inference; random augmentations
+     # (RandomResizedCrop, flips, ColorJitter) belong in training only.
+     transform = transforms.Compose([
+         transforms.Resize(256),
+         transforms.CenterCrop(224),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet RGB stats
+     ])
+     image = Image.open(image_path).convert("RGB")
+     return transform(image).unsqueeze(0)
+
+ image_tensor = getTransformations(image_path)
+
+ def predict_model_Hybrid(model, image_tensor):
+     with torch.no_grad():
+         outputs = model(image_tensor)
+         probabilities = torch.softmax(outputs, dim=1)
+         predicted_class = probabilities.argmax(dim=1).item()
+         confidence = probabilities.max(dim=1).values.item()
+
+     model_predictions = {"label": map_prediction(predicted_class), "confidence": confidence}
+     return model_predictions
+
+ print("Hybrid ViT:", predict_model_Hybrid(model, image_tensor))
+ ```
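+ Beyond the top-1 label, the full probability distribution over all five grades can be inspected — a short sketch reusing `model`, `image_tensor`, and `map_prediction` from above:
+ ```python
+ # Print the probability of every severity grade for the same image
+ with torch.no_grad():
+     probs = torch.softmax(model(image_tensor), dim=1).squeeze(0)
+ for idx, p in enumerate(probs.tolist()):
+     print(f"{map_prediction(idx)}: {p:.3f}")
+ ```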
+
+ ## 📊 Training Details
+ This model was trained on the **APTOS 2019 Blindness Detection** dataset using **5-fold cross-validation** for better generalization. Training combined an EfficientNet-B0 feature extractor with a Vision Transformer (ViT) classification head; a sketch of the fold setup follows.
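+ A minimal sketch of the 5-fold loop, assuming a labeled `dataset` (a `torch.utils.data.Dataset` of fundus images); the actual training script is not part of this repository:
+ ```python
+ from sklearn.model_selection import KFold
+ from torch.utils.data import DataLoader, Subset
+
+ kfold = KFold(n_splits=5, shuffle=True, random_state=42)
+ for fold, (train_idx, val_idx) in enumerate(kfold.split(list(range(len(dataset))))):
+     train_loader = DataLoader(Subset(dataset, train_idx), batch_size=32, shuffle=True)
+     val_loader = DataLoader(Subset(dataset, val_idx), batch_size=32)
+     # ... train CNNViT on train_loader, evaluate on val_loader ...
+ ```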
+
+ ## 🛠️ Hyperparameters
+ | Parameter | Value |
+ |-----------|-------|
+ | **Image Size** | 224x224 |
+ | **Batch Size** | 32 |
+ | **Epochs** | 5 |
+ | **K-Folds** | 5 |
+ | **Learning Rate** | 1e-4 |
+ | **Optimizer** | Adam |
+ | **Scheduler** | StepLR (Step=10, Gamma=0.5) |
+ | **Loss Function** | CrossEntropyLoss |
+ | **Device** | `CUDA` (if available) |
+
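+ In code, the table corresponds to roughly this setup — a sketch with illustrative variable names, not the original training script:
+ ```python
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+ scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
+ criterion = torch.nn.CrossEntropyLoss()
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ ```
+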
+ ## 📬 Contact
+ For any queries, reach out to me at:
103
model.py ADDED
@@ -0,0 +1,64 @@
+ import torch
+ import torch.nn as nn
+ import timm
+
+ class TransformerBlock(nn.Module):
+     """One Transformer encoder block applied to the pooled CNN feature vector."""
+     def __init__(self, embed_dim=1280, num_heads=8, ff_dim=3072, dropout=0.1):
+         super(TransformerBlock, self).__init__()
+         self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
+         self.norm1 = nn.LayerNorm(embed_dim)
+         self.norm2 = nn.LayerNorm(embed_dim)
+         self.ffn = nn.Sequential(
+             nn.Linear(embed_dim, ff_dim),
+             nn.GELU(),
+             nn.Linear(ff_dim, embed_dim),
+             nn.Dropout(dropout)
+         )
+
+     def forward(self, x):
+         x = x.unsqueeze(1)      # (batch, embed) -> (batch, 1, embed): a length-1 token sequence
+         x = x.permute(1, 0, 2)  # -> (seq, batch, embed), the default nn.MultiheadAttention layout
+
+         attn_output, _ = self.attn(x, x, x)
+         x = self.norm1(x + attn_output)   # residual + layer norm
+
+         ffn_output = self.ffn(x)
+         x = self.norm2(x + ffn_output)    # residual + layer norm
+
+         x = x.permute(1, 0, 2)  # back to (batch, seq, embed)
+         return x
+
+ class EfficientNetBackbone(nn.Module):
+     """EfficientNet-B0 feature extractor: classifier head removed, global average pooling."""
+     def __init__(self):
+         super(EfficientNetBackbone, self).__init__()
+         self.model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=0, global_pool='avg')
+         self.out_features = 1280
+
+     def forward(self, x):
+         x = self.model(x)  # (batch, 3, 224, 224) -> (batch, 1280)
+         return x
+
+ class CNNViT(nn.Module):
+     """Hybrid model: EfficientNet-B0 backbone -> Transformer block -> MLP classifier."""
+     def __init__(self, num_classes=5):
+         super(CNNViT, self).__init__()
+         self.cnn_backbone = EfficientNetBackbone()
+         self.transformer = TransformerBlock(embed_dim=1280, num_heads=8, ff_dim=3072)
+
+         self.fc = nn.Sequential(
+             nn.Linear(1280, 512),
+             nn.ReLU(),
+             nn.Dropout(0.3),
+             nn.Linear(512, 256),
+             nn.ReLU(),
+             nn.Dropout(0.3),
+             nn.Linear(256, num_classes)
+         )
+
+     def forward(self, x):
+         x = self.cnn_backbone(x)  # (batch, 1280)
+         x = self.transformer(x)   # (batch, 1, 1280)
+         x = x.squeeze(1)          # (batch, 1280)
+         x = self.fc(x)            # (batch, num_classes) logits
+         return x
+
+ model_Hybrid = CNNViT(num_classes=5)
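+
+ # Sanity-check sketch (assumes the 224x224 RGB input documented in the README):
+ # a dummy batch of two images should produce logits of shape (2, 5).
+ if __name__ == "__main__":
+     dummy = torch.randn(2, 3, 224, 224)
+     print(model_Hybrid(dummy).shape)  # torch.Size([2, 5])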