🏥 Diabetic Retinopathy Severity Classification
This model is a hybrid Vision Transformer (ViT) that uses EfficientNet-B0 as its backbone. It is trained to classify the severity of diabetic retinopathy into one of five stages.
Model Overview
- Backbone: EfficientNet-B0 (feature extractor)
- Head: Vision Transformer (ViT) for classification
- Input size: 224×224 RGB images
- Output classes:
  - 0: No Diabetic Retinopathy
  - 1: Mild
  - 2: Moderate
  - 3: Severe
  - 4: Proliferative Diabetic Retinopathy
How to Use This Model
1️⃣ Download the Model
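Make sure PyTorch and Torchvision are installed (for example, `pip install torch torchvision`). Then clone the repository and navigate to it. Git LFS is needed so that the `.pth` weights file is downloaded rather than a pointer file:
```
git lfs install
git clone https://huggingface.co/PavanKumarAmbadapudi/DiabeticRetinopathy_Hybrid-ViT
cd DiabeticRetinopathy_Hybrid-ViT
```
In a Jupyter or Colab notebook, prefix the shell commands with `!` and use `%cd` instead of `cd`.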
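Alternatively, download the two required files manually and place them in your working directory: `Hybrid_ViT.pth` (the trained weights) and `model.py` (the `CNNViT` model definition).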
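The same two files can also be fetched programmatically with the `huggingface_hub` library instead of cloning. The following is a minimal sketch that assumes `huggingface_hub` is installed (`pip install huggingface_hub`). The helper names `fetch_model_files` and `add_model_dir_to_path` are hypothetical.

```python
import os
import sys

from huggingface_hub import hf_hub_download

REPO_ID = "PavanKumarAmbadapudi/DiabeticRetinopathy_Hybrid-ViT"
FILES = ("Hybrid_ViT.pth", "model.py")


def fetch_model_files(repo_id: str = REPO_ID) -> dict:
    """Download the checkpoint and model definition into the local Hugging Face cache."""
    # hf_hub_download returns the local path of each cached file.
    return {name: hf_hub_download(repo_id=repo_id, filename=name) for name in FILES}


def add_model_dir_to_path(paths: dict) -> None:
    """Make the cached model.py importable (files from one revision share a snapshot folder)."""
    model_dir = os.path.dirname(paths["model.py"])
    if model_dir not in sys.path:
        sys.path.insert(0, model_dir)
```

For example: call `paths = fetch_model_files()`, then `add_model_dir_to_path(paths)`, then `from model import CNNViT`. Pass `paths["Hybrid_ViT.pth"]` to `torch.load` in place of the bare filename.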
2️⃣ Load the Model in Python
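```python
import torch
from model import CNNViT  # model.py from this repository

# Build the architecture with five output classes, then load the trained weights on CPU.
model = CNNViT(num_classes=5)
model.load_state_dict(torch.load("Hybrid_ViT.pth", map_location=torch.device("cpu")))
model.eval()  # inference mode (affects dropout and batch-norm layers)
```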
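The loading step targets the CPU. If a GPU may be available, the same step can be wrapped in a small helper that picks the device automatically. This is a sketch. `load_weights` is a hypothetical name, and the helper assumes `Hybrid_ViT.pth` holds a plain state dict, as the `load_state_dict` call implies.

```python
from typing import Optional

import torch
import torch.nn as nn


def load_weights(model: nn.Module, weights_path: str,
                 device: Optional[str] = None) -> nn.Module:
    """Load a state-dict checkpoint, move the model to a device, and switch to eval mode."""
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # Load the tensors directly onto the target device.
    state_dict = torch.load(weights_path, map_location=device)
    # strict=True (the default) raises if any key is missing or unexpected,
    # which catches a model.py / checkpoint mismatch early.
    model.load_state_dict(state_dict)
    return model.to(device).eval()
```

For example, `model = load_weights(CNNViT(num_classes=5), "Hybrid_ViT.pth")`. Input tensors must then be on the same device as the model, for instance `image_tensor.to(next(model.parameters()).device)`.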
3️⃣ Perform Inference
To make predictions on an image:
```python
import torch
import torchvision.transforms as transforms
from PIL import Image


def map_prediction(prediction):
    """Map a class index (0-4) to its severity label."""
    mapping = {
        0: "No DR",
        1: "Mild",
        2: "Moderate",
        3: "Severe",
        4: "Proliferative DR",
    }
    return mapping.get(prediction, "Unknown")


image_path = "Path_to_Your_Image"


def getTransformations(image_path):
    """Load an image and preprocess it into a (1, 3, 224, 224) tensor."""
    # Deterministic preprocessing for inference. Random augmentations
    # (RandomResizedCrop, RandomHorizontalFlip, ColorJitter) belong in training;
    # applying them here would change the prediction from run to run.
    # Assumption: a plain resize to 224x224 approximates the training-time crop.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # ImageNet RGB mean and standard deviation
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = Image.open(image_path).convert("RGB")
    return transform(image).unsqueeze(0)  # add a batch dimension


image_tensor = getTransformations(image_path)


def predict_model_Hybrid(model, image_tensor):
    """Return the predicted label and its softmax probability for a single image."""
    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.softmax(outputs, dim=1)
        confidence, predicted_class = probabilities.max(dim=1)
    return {"label": map_prediction(predicted_class.item()), "confidence": confidence.item()}


print("Hybrid ViT", predict_model_Hybrid(model, image_tensor))
```
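`predict_model_Hybrid` reports only the top class. Because the severity grades are ordered, it can help to see how the probability is spread across neighboring grades. The sketch below uses a hypothetical helper, `predict_with_probabilities`, with class names copied from `map_prediction`. It returns the full softmax distribution for a batch of preprocessed images.

```python
import torch
import torch.nn as nn

# Same labels and order as map_prediction above.
CLASS_NAMES = ["No DR", "Mild", "Moderate", "Severe", "Proliferative DR"]


def predict_with_probabilities(model: nn.Module, batch: torch.Tensor) -> list:
    """Return the top label plus the softmax probability of every class, per image."""
    model.eval()
    with torch.no_grad():
        probs = torch.softmax(model(batch), dim=1)  # (B, 5)
    results = []
    for row in probs:
        conf, idx = row.max(dim=0)
        results.append({
            "label": CLASS_NAMES[idx.item()],
            "confidence": conf.item(),
            "probabilities": {name: p.item() for name, p in zip(CLASS_NAMES, row)},
        })
    return results
```

To build a batch from several files, concatenate the single-image tensors along the first dimension. For example, use `torch.cat([getTransformations(p) for p in paths])`, where `paths` is your list of image paths.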
💬 Contact
For any queries, reach out to me at: 📧 [email protected]