---
# Model Card for Hybrid ViT - Diabetic Retinopathy Classification
library_name: pytorch
tags:
- classification
- diabetic-retinopathy
- hybrid vit
- efficientnet
- deep-learning
- medical-ai
model_name: Hybrid_ViT
model_description: "A Hybrid Vision Transformer (ViT) with EfficientNetB0 as the backbone, trained to classify Diabetic Retinopathy severity into 5 categories."
dataset:
  name: "APTOS 2019 Blindness Detection"
  url: "https://www.kaggle.com/competitions/aptos2019-blindness-detection"
  description: "A dataset containing retinal fundus images for classifying Diabetic Retinopathy severity."
metrics:
  accuracy: 85.4
  loss: 0.23
  validation_accuracy: 84.5
  validation_loss: 0.25
training:
  epochs: 5
  batch_size: 32
  optimizer: Adam
  learning_rate: 1e-4
  scheduler:
    type: StepLR
    step_size: 10
    gamma: 0.5
  loss_function: CrossEntropyLoss
  k_folds: 5
  device: "CUDA"
model_architecture:
  backbone: EfficientNetB0
  transformer: Vision Transformer (ViT)
  num_classes: 5
  input_size: 224x224
framework: PyTorch
author: "Pavan Kumar Ambadapudi"
repository: "https://huggingface.co/PavanKumarAmbadapudi/DiabeticRetinopathy_Hybrid-ViT"
---

# 🏥 Diabetic Retinopathy Severity Classification

This model is a **Hybrid Vision Transformer (ViT) with EfficientNet B0** as the backbone. It is trained to classify the severity of **Diabetic Retinopathy** into different stages.

## 📌 Model Overview
- **Backbone**: EfficientNet B0 (Feature Extractor)
- **Head**: Vision Transformer (ViT) for Classification
- **Input Size**: 224x224 (RGB Images)
- **Output Classes**:
  - 0: No Diabetic Retinopathy
  - 1: Mild
  - 2: Moderate
  - 3: Severe
  - 4: Proliferative Diabetic Retinopathy

---

## 🚀 How to Use This Model

### **1️⃣ Download the Model**
Make sure you have **PyTorch** and **Torchvision** installed (for example, `pip install torch torchvision pillow`), then clone the repository and navigate to it:
```bash
git clone https://huggingface.co/PavanKumarAmbadapudi/DiabeticRetinopathy_Hybrid-ViT
cd DiabeticRetinopathy_Hybrid-ViT
```
Or manually download the files: `Hybrid_ViT.pth`, `model.py`

### **2️⃣ Load the Model in Python**
```python
import torch
from model import CNNViT  # model.py from this repository

model = CNNViT(num_classes=5)
# Load the trained weights onto the CPU
model.load_state_dict(torch.load("Hybrid_ViT.pth", map_location=torch.device("cpu")))
model.eval()  # disable dropout and use running batch-norm statistics
```

### **3️⃣ Perform Inference**
To make predictions on an image:
```python
from PIL import Image
import torch
import torchvision.transforms as transforms

def map_prediction(prediction):
    """Map a class index to its severity label."""
    mapping = {
        0: "No DR",
        1: "Mild",
        2: "Moderate",
        3: "Severe",
        4: "Proliferative DR"
    }
    return mapping.get(prediction, "Unknown")

image_path = 'Path_to_Your_Image'

def getTransformations(image_path):
    # Deterministic preprocessing for inference. Random augmentations such as
    # RandomResizedCrop, RandomHorizontalFlip, and ColorJitter are meant for
    # training and would make predictions vary between runs.
    transform = transforms.Compose([
        # Assumed resize to the model's input size; match the training pipeline if it differs
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])  # Use RGB mean and std
    ])
    image = Image.open(image_path).convert("RGB")
    return transform(image).unsqueeze(0)  # add a batch dimension

image_tensor = getTransformations(image_path)

def predict_model_Hybrid(model, image_tensor):
    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.softmax(outputs, dim=1)
        predicted_classes = probabilities.argmax(dim=1).item()
        confidences = probabilities.max(dim=1).values.item()
    model_predictions = {"label": map_prediction(predicted_classes), "confidence": confidences}
    return model_predictions

print("Hybrid ViT ", predict_model_Hybrid(model, image_tensor))
```

## 📬 Contact
For any queries, reach out to me at: 📧 pavan.ambadapudi@gmail.com