πŸ₯ Diabetic Retinopathy Severity Classification

This model is a hybrid Vision Transformer (ViT) that uses EfficientNet B0 as its backbone. It is trained to classify the severity of diabetic retinopathy into five stages, from no disease to proliferative.

📌 Model Overview

  • Backbone: EfficientNet B0 (Feature Extractor)
  • Head: Vision Transformer (ViT) for Classification
  • Input Size: 224x224 (RGB Images)
  • Output Classes:
    • 0: No Diabetic Retinopathy
    • 1: Mild
    • 2: Moderate
    • 3: Severe
    • 4: Proliferative Diabetic Retinopathy
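
To give a rough sense of how a hybrid of this kind connects its two parts, the sketch below works out how many tokens the ViT head would receive if it consumed the final EfficientNet B0 feature map. This is an assumption about the design, not something read from model.py: it presumes the standard EfficientNet B0 total stride of 32 and 1280 output channels, and that each spatial position becomes one token. The function name is hypothetical.

```python
def feature_map_tokens(input_size=224, total_stride=32, channels=1280):
    """Return (grid_side, num_tokens, token_dim) for a CNN feature map
    flattened into a token sequence. All defaults are assumptions."""
    # Ceiling division mirrors "same"-padded strided convolutions
    grid_side = -(-input_size // total_stride)
    num_tokens = grid_side * grid_side
    return grid_side, num_tokens, channels

grid, tokens, dim = feature_map_tokens()
print(f"{grid}x{grid} feature map -> {tokens} tokens of dimension {dim}")
```

Under these assumptions a 224x224 input yields a short sequence, which is one reason a CNN backbone is a common front end for a transformer head.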

🚀 How to Use This Model

1️⃣ Download the Model

Make sure PyTorch and Torchvision are installed, then clone the repository and navigate into it (the leading ! runs the command from a notebook cell; omit it in a regular shell):

!git clone https://huggingface.co/PavanKumarAmbadapudi/DiabeticRetinopathy_Hybrid-ViT
cd DiabeticRetinopathy_Hybrid-ViT

Or manually download the files:

Hybrid_ViT.pth, model.py
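
Before loading the model, it can help to confirm that both files ended up in the working directory. The following is a minimal sketch using only the standard library; the helper name missing_files is hypothetical and not part of the repository.

```python
from pathlib import Path

REQUIRED_FILES = ("Hybrid_ViT.pth", "model.py")  # the two files named above

def missing_files(directory="."):
    """Return the required files that are not present in `directory`."""
    base = Path(directory)
    return [name for name in REQUIRED_FILES if not (base / name).is_file()]

missing = missing_files(".")
if missing:
    print("Missing files:", ", ".join(missing))
else:
    print("All required files found.")
```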

2️⃣ Load the Model in Python

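import torch
from model import CNNViT  # model.py must be in the working directory

model = CNNViT(num_classes=5)  # five severity classes (0-4)
# Load the weights onto the CPU; change map_location to use a GPU
model.load_state_dict(torch.load("Hybrid_ViT.pth", map_location=torch.device("cpu")))
model.eval()  # switch to inference mode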

3️⃣ Perform Inference

To make predictions on an image:

from PIL import Image
import torchvision.transforms as transforms

def map_prediction(prediction):
    mapping = {
        0: "No DR",
        1: "Mild",
        2: "Moderate",
        3: "Severe",
        4: "Proliferative DR"
    }
    return mapping.get(prediction, "Unknown")

image_path = 'Path_to_Your_Image'  

def getTransformations(image_path):
    # Deterministic preprocessing for inference. Random crops, flips, and
    # color jitter are training-time augmentations and would make the
    # prediction vary from run to run.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet RGB mean and std
    ])
    image = Image.open(image_path).convert("RGB")
    return transform(image).unsqueeze(0)  # add a batch dimension

image_tensor = getTransformations(image_path)

def predict_model_Hybrid(model, image_tensor):
    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.softmax(outputs, dim=1)
        predicted_class = probabilities.argmax(dim=1).item()
        confidence = probabilities.max(dim=1).values.item()

    return {"label": map_prediction(predicted_class), "confidence": confidence}

print("Hybrid ViT ", predict_model_Hybrid(model, image_tensor))
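
The prediction step above reports only the top class. When the full distribution over the five stages is of interest, the same softmax and argmax arithmetic can be sketched in plain Python. The logits below are made-up values for illustration, not model output, and the helper names softmax and describe are hypothetical.

```python
import math

LABELS = ["No DR", "Mild", "Moderate", "Severe", "Proliferative DR"]

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)  # subtract the max so exp() cannot overflow
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def describe(logits):
    """Return the top label, its confidence, and all class probabilities."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)  # argmax
    return {
        "label": LABELS[best],
        "confidence": probs[best],
        "probabilities": dict(zip(LABELS, probs)),
    }

example_logits = [0.2, 1.5, 3.0, 0.1, -1.0]  # made-up values, not model output
result = describe(example_logits)
print(result["label"], round(result["confidence"], 3))
```

With the loaded model, the logits for a single image could be taken from model(image_tensor)[0].tolist() inside a torch.no_grad() block.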

📬 Contact

For any queries, reach out to me at: 📧 [email protected]
