library_name: pytorch
tags:
  - classification
  - diabetic-retinopathy
  - hybrid-vit
  - efficientnet
  - deep-learning
  - medical-ai
model_name: Hybrid_ViT
model_description: >-
  A Hybrid Vision Transformer (ViT) with EfficientNetB0 as the backbone, trained
  to classify Diabetic Retinopathy severity into 5 categories.
dataset:
  name: APTOS 2019 Blindness Detection
  url: https://www.kaggle.com/competitions/aptos2019-blindness-detection
  description: >-
    A dataset containing retinal fundus images for classifying Diabetic
    Retinopathy severity.
metrics:
  accuracy: 85.4
  loss: 0.23
  validation_accuracy: 84.5
  validation_loss: 0.25
training:
  epochs: 5
  batch_size: 32
  optimizer: Adam
  learning_rate: 0.0001
  scheduler:
    type: StepLR
    step_size: 10
    gamma: 0.5
  loss_function: CrossEntropyLoss
  k_folds: 5
  device: CUDA
model_architecture:
  backbone: EfficientNetB0
  transformer: Vision Transformer (ViT)
  num_classes: 5
  input_size: 224x224
  framework: PyTorch
author: Pavan Kumar Ambadapudi
repository: https://huggingface.co/PavanKumarAmbadapudi/DiabeticRetinopathy_Hybrid-ViT

πŸ₯ Diabetic Retinopathy Severity Classification

This model is a Hybrid Vision Transformer (ViT) with EfficientNetB0 as the backbone, trained on the APTOS 2019 Blindness Detection dataset to classify Diabetic Retinopathy severity into five stages. A rough sketch of one way such a hybrid can be wired together follows the overview list below.

📌 Model Overview

  • Backbone: EfficientNet B0 (Feature Extractor)
  • Head: Vision Transformer (ViT) for Classification
  • Input Size: 224x224 (RGB Images)
  • Output Classes:
    • 0: No Diabetic Retinopathy
    • 1: Mild
    • 2: Moderate
    • 3: Severe
    • 4: Proliferative Diabetic Retinopathy
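
The exact architecture is defined in model.py (class CNNViT). Purely as an illustrative sketch of how an EfficientNetB0 feature extractor can feed a transformer encoder for classification, one possible wiring in PyTorch is shown below; the class name HybridViTSketch, the embedding size, encoder depth, and token handling are assumptions and need not match the released weights:

import torch
import torch.nn as nn
from torchvision.models import efficientnet_b0

class HybridViTSketch(nn.Module):
    """Illustrative sketch only: EfficientNetB0 features feeding a small transformer encoder."""
    def __init__(self, num_classes=5, embed_dim=256, depth=4, num_heads=8):
        super().__init__()
        # EfficientNetB0 convolutional stages: (B, 1280, 7, 7) feature map for a 224x224 input
        # (ImageNet-pretrained backbone is an assumption here)
        self.backbone = efficientnet_b0(weights="IMAGENET1K_V1").features
        self.proj = nn.Conv2d(1280, embed_dim, kernel_size=1)        # project channels to embed_dim
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))  # learnable [CLS] token
        self.pos_embed = nn.Parameter(torch.zeros(1, 7 * 7 + 1, embed_dim))
        layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        feats = self.proj(self.backbone(x))              # (B, embed_dim, 7, 7)
        tokens = feats.flatten(2).transpose(1, 2)        # (B, 49, embed_dim) patch tokens
        cls = self.cls_token.expand(x.size(0), -1, -1)
        tokens = torch.cat([cls, tokens], dim=1) + self.pos_embed
        return self.head(self.encoder(tokens)[:, 0])     # classify from the [CLS] token

For actual use, load the released CNNViT class and Hybrid_ViT.pth weights as shown in the steps below rather than this sketch.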

🚀 How to Use This Model

1️⃣ Download the Model

Make sure you have PyTorch and Torchvision installed.
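
A typical setup (one example; the exact command depends on your Python and CUDA environment) is:

pip install torch torchvision

Then clone the repository and navigate into it: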

git clone https://huggingface.co/PavanKumarAmbadapudi/DiabeticRetinopathy_Hybrid-ViT
cd DiabeticRetinopathy_Hybrid-ViT

Or manually download the files:

Hybrid_ViT.pth, model.py
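
Alternatively, the same two files can be fetched programmatically with the huggingface_hub library (assuming it is installed, e.g. via pip install huggingface_hub):

from huggingface_hub import hf_hub_download

repo_id = "PavanKumarAmbadapudi/DiabeticRetinopathy_Hybrid-ViT"
weights_path = hf_hub_download(repo_id=repo_id, filename="Hybrid_ViT.pth")   # model weights
model_def_path = hf_hub_download(repo_id=repo_id, filename="model.py")       # CNNViT definition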

2️⃣ Load the Model in Python

import torch
from model import CNNViT  # model.py from this repository

model = CNNViT(num_classes=5)  # 5 severity classes
model.load_state_dict(torch.load("Hybrid_ViT.pth", map_location=torch.device('cpu')))
model.eval()  # switch to inference mode
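
The weights above are loaded onto the CPU. If a GPU is available (the model card lists CUDA as the training device), you can optionally move the model to it; note that the inference example below assumes the CPU default, so image_tensor would also need to be moved to the same device before calling the model:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # pick GPU when present
model = model.to(device)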

3️⃣ Perform Inference

To make predictions on an image:

from PIL import Image
import torchvision.transforms as transforms

def map_prediction(prediction):
    mapping = {
        0: "No DR",
        1: "Mild",
        2: "Moderate",
        3: "Severe",
        4: "Proliferative DR"
    }
    return mapping.get(prediction, "Unknown")

image_path = 'Path_to_Your_Image'  

def getTransformations(image_path):
    # Deterministic preprocessing for inference: resize, center-crop to 224x224,
    # then normalize with the standard ImageNet RGB mean and std
    # (random augmentations are only needed during training).
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    image = Image.open(image_path).convert("RGB")
    return transform(image).unsqueeze(0)  # add a batch dimension: (1, 3, 224, 224)

image_tensor = getTransformations(image_path)

def predict_model_Hybrid(model, image_tensor):
    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.softmax(outputs, dim=1)          # class probabilities
        predicted_class = probabilities.argmax(dim=1).item()   # most likely class index
        confidence = probabilities.max(dim=1).values.item()    # its probability
    return {"label": map_prediction(predicted_class), "confidence": confidence}

print("Hybrid ViT", predict_model_Hybrid(model, image_tensor))

📬 Contact

For any queries, reach out to me at: 📧 [email protected]