Diabetic Retinopathy Detection Model (IDRiD)

This is a deep learning model trained to detect the severity of diabetic retinopathy from retinal fundus images. As defined by the model architecture and training script provided, it classifies each image into one of five severity grades (0-4).

Model Details

  • Model Type: Custom Deep Convolutional Neural Network (CNN) with ResNet-style residual blocks.
  • Framework: TensorFlow / Keras.
  • Input Shape: (384, 384, 3) (RGB images)
  • Output Classes: 5 (corresponding to retinopathy grades 0 through 4)
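The card does not list the network's depth, filter counts, or block layout. The following is therefore only a minimal sketch of what "a custom CNN with ResNet-style residual blocks" for a (384, 384, 3) input and 5 output classes can look like in Keras. The helper names `residual_block` and `build_model`, the layer sizes, and the softmax output layer are all assumptions made for illustration, not the trained model's actual definition.

```python
import tensorflow as tf
from tensorflow.keras import layers

# Hypothetical sketch: depth, filter counts and the softmax head are assumptions,
# not the published model's architecture.

def residual_block(x, filters, stride=1):
    """ResNet-style block: two 3x3 convolutions plus a shortcut connection."""
    shortcut = x
    y = layers.Conv2D(filters, 3, strides=stride, padding="same", use_bias=False)(x)
    y = layers.BatchNormalization()(y)
    y = layers.ReLU()(y)
    y = layers.Conv2D(filters, 3, padding="same", use_bias=False)(y)
    y = layers.BatchNormalization()(y)
    # Project the shortcut with a 1x1 convolution when spatial size or channels change
    if stride != 1 or shortcut.shape[-1] != filters:
        shortcut = layers.Conv2D(filters, 1, strides=stride, use_bias=False)(shortcut)
        shortcut = layers.BatchNormalization()(shortcut)
    y = layers.Add()([y, shortcut])  # the residual (skip) connection
    return layers.ReLU()(y)

def build_model(input_shape=(384, 384, 3), num_classes=5):
    inputs = tf.keras.Input(shape=input_shape)
    # Stem: downsample the high-resolution fundus image early
    x = layers.Conv2D(32, 7, strides=2, padding="same", use_bias=False)(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.MaxPooling2D(3, strides=2, padding="same")(x)
    # Stack of residual blocks, halving resolution as channels grow
    for filters, stride in [(32, 1), (64, 2), (128, 2), (256, 2)]:
        x = residual_block(x, filters, stride)
    x = layers.GlobalAveragePooling2D()(x)
    # One probability per retinopathy grade (0-4)
    outputs = layers.Dense(num_classes, activation="softmax")(x)
    return tf.keras.Model(inputs, outputs)

model = build_model()
```

Each block adds its input back to the output of its convolutions, which is what makes it "residual"; a 1x1 projection is used only when the shapes on the two paths differ.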

Dataset

The model was trained on the Indian Diabetic Retinopathy Image Dataset (IDRiD), which provides high-resolution retinal fundus images annotated with diabetic retinopathy severity grades.
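IDRiD stores each image's grade as an integer (0-4), while the Categorical Crossentropy loss used in training compares the model's output against one-hot target vectors. The sketch below shows that conversion and the loss formula, L = -Σ_c y_c log(p_c), averaged over a batch, in plain NumPy. The helpers `one_hot` and `categorical_crossentropy` are hypothetical; the card does not show how `train.py` actually encodes its labels.

```python
import numpy as np

NUM_GRADES = 5  # retinopathy grades 0-4

def one_hot(grades, num_classes=NUM_GRADES):
    """Convert integer grades, e.g. [0, 2, 4], into one-hot rows."""
    grades = np.asarray(grades, dtype=int)
    return np.eye(num_classes, dtype=np.float32)[grades]

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Batch mean of -sum_c y_true[c] * log(y_pred[c])."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))

labels = one_hot([0, 2, 4])              # shape (3, 5)
uniform_guess = np.full((3, 5), 0.2)     # a model with no information
print(categorical_crossentropy(labels, uniform_guess))  # equals log(5)
```

A uniform guess over five grades costs log(5) per image, and the loss approaches 0 as the probability assigned to the true grade approaches 1.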

How to Use

You can load this model directly from the Hugging Face Hub using the huggingface_hub library.

```python
import tensorflow as tf
from huggingface_hub import from_pretrained_keras
import cv2
import numpy as np

# Load the model from the Hugging Face Hub
try:
    model = from_pretrained_keras("Arko007/diabetic-retinopathy-v1")
    print("Model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    raise  # stop here: the code below needs `model`

def preprocess_image(image_path):
    '''Load an image as a (384, 384, 3) float32 RGB array scaled to [0, 1].

    This preprocessing must match the one used in training.
    '''
    img = cv2.imread(image_path)  # OpenCV reads images in BGR order
    if img is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (384, 384))  # dsize is (width, height)
    img = img.astype(np.float32) / 255.0
    return img

# Preprocess your image to a (384, 384, 3) array
# example_image = preprocess_image("path/to/your/image.jpg")

# The model expects a batch, so add a batch dimension: (1, 384, 384, 3)
# example_image_batch = tf.expand_dims(example_image, axis=0)

# Get predictions: an array of shape (1, 5) with one score per grade
# predictions = model.predict(example_image_batch)
# predicted_class = tf.argmax(predictions, axis=1).numpy()[0]

# print(f"Predicted Retinopathy Grade: {predicted_class}")
```
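To report a readable label alongside the numeric grade, the five outputs can be mapped to the severity levels IDRiD uses, which follow the International Clinical Diabetic Retinopathy scale. This is a sketch that assumes output index i corresponds to IDRiD grade i and that the outputs are softmax probabilities; `GRADE_NAMES` and `summarize_prediction` are hypothetical names, not part of the model repository.

```python
import numpy as np

# Assumption: output index i corresponds to IDRiD retinopathy grade i
# (International Clinical Diabetic Retinopathy severity scale).
GRADE_NAMES = {
    0: "No apparent retinopathy",
    1: "Mild NPDR",
    2: "Moderate NPDR",
    3: "Severe NPDR",
    4: "Proliferative DR (PDR)",
}

def summarize_prediction(probs):
    """Turn one row of model output (5 class scores) into (grade, name, score)."""
    probs = np.asarray(probs, dtype=np.float64).ravel()
    grade = int(np.argmax(probs))
    return grade, GRADE_NAMES[grade], float(probs[grade])

# Example with made-up scores for a single image
print(summarize_prediction([0.1, 0.05, 0.7, 0.1, 0.05]))
```

With a real image, pass the first row of the `model.predict` output, e.g. `summarize_prediction(predictions[0])`. The returned score is the model's raw output for the chosen grade, not a calibrated clinical probability.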

Training Procedure

The model was trained using the train.py script. Key aspects of the training include:

  • Optimizer: Adam with a learning rate of 1e-4.
  • Loss Function: Categorical Crossentropy.
  • Callbacks: ModelCheckpoint to save the best model based on validation accuracy, and ReduceLROnPlateau to lower the learning rate when a monitored validation metric stops improving.
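The settings listed above translate to a Keras `compile` call plus two callbacks. The sketch below is an approximation, not the contents of `train.py`: the helper name `compile_and_get_callbacks`, the checkpoint path, the `val_loss` monitor for ReduceLROnPlateau, and its `factor`, `patience`, and `min_lr` values are assumptions; only the optimizer, learning rate, loss, and the validation-accuracy checkpoint criterion come from this card.

```python
import tensorflow as tf

def compile_and_get_callbacks(model, checkpoint_path="best_model.keras"):
    """Apply the listed training setup; path and LR-schedule values are assumed."""
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
        loss="categorical_crossentropy",  # expects one-hot labels
        metrics=["accuracy"],             # makes `val_accuracy` available
    )
    callbacks = [
        # Save only when validation accuracy reaches a new best
        tf.keras.callbacks.ModelCheckpoint(
            checkpoint_path, monitor="val_accuracy", mode="max", save_best_only=True
        ),
        # Halve the learning rate after 3 epochs without val_loss improvement (assumed values)
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor="val_loss", factor=0.5, patience=3, min_lr=1e-7
        ),
    ]
    return callbacks
```

The callbacks are then passed to training, for example `model.fit(train_ds, validation_data=val_ds, epochs=..., callbacks=callbacks)`. Validation data is required, since both callbacks monitor validation metrics.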