Diabetic Retinopathy Detection Model (IDRiD)
This is a deep learning model trained to detect the severity of diabetic retinopathy from retinal fundus images. It classifies each image into one of five severity grades (0 to 4), as defined by the model architecture and training script provided.
Model Details
- Model Type: Custom Deep Convolutional Neural Network (CNN) with ResNet-style residual blocks.
- Framework: TensorFlow / Keras.
- Input Shape: (384, 384, 3) (RGB images)
- Output Classes: 5 (corresponding to retinopathy grades 0 through 4)
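To make the architecture description more concrete, here is a minimal Keras sketch of ResNet-style residual blocks stacked into a 5-class classifier with the stated input shape. The number of blocks, the filter sizes, and the names residual_block and build_model are illustrative assumptions, not the architecture actually defined in train.py.

```python
from tensorflow import keras

layers = keras.layers

def residual_block(x, filters, stride=1):
    """Two 3x3 convolutions plus a skip connection (illustrative, not the trained model)."""
    shortcut = x
    y = layers.Conv2D(filters, 3, strides=stride, padding="same", use_bias=False)(x)
    y = layers.BatchNormalization()(y)
    y = layers.ReLU()(y)
    y = layers.Conv2D(filters, 3, padding="same", use_bias=False)(y)
    y = layers.BatchNormalization()(y)
    # Project the shortcut with a 1x1 conv when the spatial size or channel count changes.
    if stride != 1 or shortcut.shape[-1] != filters:
        shortcut = layers.Conv2D(filters, 1, strides=stride, use_bias=False)(shortcut)
        shortcut = layers.BatchNormalization()(shortcut)
    y = layers.Add()([y, shortcut])
    return layers.ReLU()(y)

def build_model(input_shape=(384, 384, 3), num_classes=5):
    """Hypothetical stem + residual stages + softmax head for grades 0-4."""
    inputs = keras.Input(shape=input_shape)
    x = layers.Conv2D(32, 7, strides=2, padding="same", use_bias=False)(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.MaxPooling2D(3, strides=2, padding="same")(x)
    # Assumed stage widths; each later stage halves the spatial resolution.
    for filters, stride in [(32, 1), (64, 2), (128, 2), (256, 2)]:
        x = residual_block(x, filters, stride)
    x = layers.GlobalAveragePooling2D()(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)
    return keras.Model(inputs, outputs)

model = build_model()
```

The skip connection lets each block learn a correction on top of its input, which is what makes deep stacks of such blocks trainable.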
Dataset
The model was trained on the Indian Diabetic Retinopathy Image Dataset (IDRiD), which contains high-resolution retinal fundus images, each labeled with a diabetic retinopathy severity grade.
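IDRiD assigns one integer grade per image, while the categorical crossentropy loss listed under Training Procedure expects one-hot targets. The following NumPy sketch shows that conversion and the loss it feeds, assuming the labels are already loaded as integers 0 to 4; one_hot_grades and categorical_crossentropy are hypothetical helper names, not part of the training script.

```python
import numpy as np

NUM_CLASSES = 5  # retinopathy grades 0-4

def one_hot_grades(grades, num_classes=NUM_CLASSES):
    """Convert integer grades of shape (N,) into one-hot targets of shape (N, num_classes)."""
    grades = np.asarray(grades, dtype=np.int64)
    if grades.min() < 0 or grades.max() >= num_classes:
        raise ValueError("grades must lie in [0, num_classes)")
    # Row i of the identity matrix is the one-hot vector for class i.
    return np.eye(num_classes, dtype=np.float32)[grades]

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Batch mean of -sum_c y_true[c] * log(y_pred[c]); clipping avoids log(0)."""
    y_pred = np.clip(np.asarray(y_pred, dtype=np.float64), eps, 1.0 - eps)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))

y_true = one_hot_grades([2, 0])
y_pred = np.array([[0.05, 0.10, 0.70, 0.10, 0.05],
                   [0.90, 0.05, 0.03, 0.01, 0.01]])
loss = categorical_crossentropy(y_true, y_pred)
```

Because the targets are one-hot, only the predicted probability of the true grade contributes to each sample's loss, so the example reduces to the mean of -log(0.70) and -log(0.90).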
How to Use
You can load this model directly from the Hugging Face Hub using the huggingface_hub library:
```python
import tensorflow as tf
from huggingface_hub import from_pretrained_keras
import cv2
import numpy as np

# Load the model from the Hugging Face Hub
try:
    model = from_pretrained_keras("Arko007/diabetic-retinopathy-v1")
    print("Model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    raise  # without a model, the prediction steps below cannot run

def preprocess_image(image_path):
    '''Your preprocessing function must match the one used in training.'''
    img = cv2.imread(image_path)  # BGR uint8 array, or None if the file cannot be read
    if img is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; the model expects RGB
    img = cv2.resize(img, (384, 384))
    img = img.astype(np.float32) / 255.0  # scale pixel values to [0, 1]
    return img

# Preprocess your image to a (384, 384, 3) tensor
# example_image = preprocess_image("path/to/your/image.jpg")
# The model expects a batch, so add a batch dimension
# example_image_batch = tf.expand_dims(example_image, axis=0)
# Get predictions
# predictions = model.predict(example_image_batch)
# predicted_class = tf.argmax(predictions, axis=1).numpy()[0]
# print(f"Predicted Retinopathy Grade: {predicted_class}")
```
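The model returns one score per grade, so a batch of predictions has shape (N, 5). The sketch below decodes such an array into a grade and a confidence value. It assumes that output index i corresponds to grade i and that the final layer is a softmax, so scores can be read as probabilities. The grade names follow the International Clinical Diabetic Retinopathy severity scale, and decode_predictions is a hypothetical helper.

```python
import numpy as np

# Assumed mapping: output index i is retinopathy grade i (0-4).
GRADE_NAMES = (
    "No apparent retinopathy",
    "Mild NPDR",
    "Moderate NPDR",
    "Severe NPDR",
    "Proliferative DR",
)

def decode_predictions(predictions):
    """Turn an (N, 5) array of softmax scores into (grade, name, confidence) tuples."""
    probs = np.asarray(predictions, dtype=np.float64)
    if probs.ndim != 2 or probs.shape[1] != len(GRADE_NAMES):
        raise ValueError(f"expected shape (N, {len(GRADE_NAMES)}), got {probs.shape}")
    grades = probs.argmax(axis=1)  # most probable grade per image
    return [(int(g), GRADE_NAMES[g], float(probs[i, g])) for i, g in enumerate(grades)]

# Stand-in for model.predict(...) output on two images.
preds = np.array([[0.05, 0.10, 0.70, 0.10, 0.05],
                  [0.90, 0.05, 0.03, 0.01, 0.01]])
results = decode_predictions(preds)
for grade, name, confidence in results:
    print(f"Grade {grade} ({name}), confidence {confidence:.2f}")
```

Reporting the confidence alongside the grade makes it easier to flag borderline cases for manual review.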
Training Procedure
The model was trained using the train.py script. Key aspects of the training include:
- Optimizer: Adam with a learning rate of 1e-4.
- Loss Function: Categorical Crossentropy.
- Callbacks: ModelCheckpoint to save the best model based on validation accuracy, and ReduceLROnPlateau to lower the learning rate when the monitored metric stops improving.
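The listed training settings could be expressed in Keras roughly as follows. Only the optimizer, learning rate, loss, and the checkpoint's validation-accuracy criterion come from this card. The checkpoint filename, the metric ReduceLROnPlateau monitors, and its factor, patience, and min_lr values are assumptions, and compile_for_training is a hypothetical helper.

```python
from tensorflow import keras

def compile_for_training(model):
    """Compile with Adam(1e-4) and categorical crossentropy; return the two callbacks."""
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=1e-4),
        loss="categorical_crossentropy",
        metrics=["accuracy"],  # needed so that "val_accuracy" exists during fit
    )
    callbacks = [
        # Keep only the checkpoint with the highest validation accuracy.
        keras.callbacks.ModelCheckpoint(
            "best_model.keras",  # hypothetical filename
            monitor="val_accuracy",
            mode="max",
            save_best_only=True,
        ),
        # Assumed settings: halve the learning rate after 3 epochs without
        # improvement in validation loss, never going below 1e-6.
        keras.callbacks.ReduceLROnPlateau(
            monitor="val_loss", factor=0.5, patience=3, min_lr=1e-6
        ),
    ]
    return callbacks
```

The callbacks would then be passed to model.fit(..., validation_data=..., callbacks=callbacks). Both callbacks rely on validation metrics, so they only take effect when validation data is supplied.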