Commit
·
4f95b95
1
Parent(s):
f93191d
Initial commit
Browse files
README.md
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
<<<<<<< HEAD
|
2 |
---
|
3 |
license: apache-2.0
|
4 |
language: en
|
@@ -13,9 +12,9 @@ tags:
|
|
13 |
datasets:
|
14 |
- aptos2019-blindness-detection
|
15 |
widget:
|
16 |
-
- src:
|
17 |
example_title: No DR Example
|
18 |
-
- src:
|
19 |
example_title: Severe DR Example
|
20 |
---
|
21 |
|
@@ -41,9 +40,184 @@ The model can be easily loaded from Hugging Face Hub for inference.
|
|
41 |
|
42 |
```bash
|
43 |
# Install required libraries
|
44 |
-
pip install torch torchvision timm albumentations huggingface-hub numpy
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
license: apache-2.0
|
3 |
language: en
|
|
|
12 |
datasets:
|
13 |
- aptos2019-blindness-detection
|
14 |
widget:
|
15 |
+
- src: gradcam_visualizations/gradcam_sample_003.png
|
16 |
example_title: No DR Example
|
17 |
+
- src: gradcam_visualizations/gradcam_sample_007.png
|
18 |
example_title: Severe DR Example
|
19 |
---
|
20 |
|
|
|
40 |
|
41 |
```bash
|
42 |
# Install required libraries
|
43 |
+
pip install torch torchvision timm albumentations huggingface-hub numpy pillow opencv-python
|
44 |
+
```
|
45 |
+
|
46 |
+
```python
|
47 |
+
import torch
|
48 |
+
import torch.nn as nn
|
49 |
+
import torch.nn.functional as F
|
50 |
+
import timm
|
51 |
+
from PIL import Image
|
52 |
+
import numpy as np
|
53 |
+
import albumentations as A
|
54 |
+
from albumentations.pytorch import ToTensorV2
|
55 |
+
from huggingface_hub import hf_hub_download
|
56 |
+
|
57 |
+
# Define the model architecture
|
58 |
+
class MultiTaskDRModel(nn.Module):
    """Multi-task diabetic-retinopathy network on a shared timm backbone.

    The backbone (created with ``num_classes=0``) produces a feature map that
    is average-pooled, re-weighted per channel by a small gating branch,
    batch-normalised and passed through dropout. Three MLP heads then emit
    severity, lesion and region logits from that shared vector.

    Args:
        model_name: Architecture name handed to ``timm.create_model``.
        num_classes: Output width of the severity head.
        num_lesion_types: Output width of the lesion head.
        num_regions: Output width of the region head.
        pretrained: Forwarded unchanged to ``timm.create_model``.
    """

    def __init__(self, model_name='efficientnet_b3', num_classes=5,
                 num_lesion_types=5, num_regions=5, pretrained=False):
        super().__init__()
        self.backbone = timm.create_model(
            model_name, pretrained=pretrained, num_classes=0
        )
        dim = self.backbone.num_features
        self.feature_dim = dim

        # The position of each layer inside an nn.Sequential determines its
        # parameter name in the state dict, so the layer order is kept as-is.
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(dim, dim // 8),
            nn.ReLU(inplace=True),
            nn.Linear(dim // 8, dim),
            nn.Sigmoid(),
        )

        self.feature_norm = nn.BatchNorm1d(dim)
        self.dropout = nn.Dropout(0.4)

        self.severity_classifier = self._build_head(dim, dim // 2, num_classes)
        self.lesion_detector = self._build_head(dim, dim // 4, num_lesion_types)
        self.region_predictor = self._build_head(dim, dim // 4, num_regions)

    @staticmethod
    def _build_head(in_dim, hidden_dim, out_dim):
        """Return a Linear -> ReLU -> Dropout(0.2) -> Linear prediction head."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        """Compute the logits of all three tasks for a batch of images.

        Returns:
            dict with keys ``'severity'``, ``'lesions'`` and ``'regions'``
            (the head outputs) and ``'features'`` (the shared vector that was
            fed to the heads).
        """
        feature_map = self.backbone.forward_features(x)
        pooled = F.adaptive_avg_pool2d(feature_map, 1).flatten(1)

        # The gating branch begins with a 2-D pooling layer, so the pooled
        # vector is given two trailing singleton axes before it is passed in.
        gate = self.attention(pooled[:, :, None, None])
        shared = self.dropout(self.feature_norm(pooled * gate))

        return {
            'severity': self.severity_classifier(shared),
            'lesions': self.lesion_detector(shared),
            'regions': self.region_predictor(shared),
            'features': shared,
        }
|
107 |
+
|
108 |
+
# Load the model
|
109 |
+
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = MultiTaskDRModel()

# Fetch the checkpoint file from the Hugging Face Hub.
model_path = hf_hub_download(
    repo_id="dheeren-tejani/DiabeticRetinpathyClassifier",
    filename="best_model_v2.pth",
)

# NOTE(review): weights_only=False allows torch.load to unpickle arbitrary
# Python objects from the downloaded file; only use it with a checkpoint
# source you trust.
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])

# Move the weights to the selected device and switch to inference mode.
model.to(device).eval()

print("Model loaded successfully!")
|
123 |
+
|
124 |
+
# Preprocessing function
|
125 |
+
def preprocess_image(image_path):
    """Load an image file and convert it into a normalised model input.

    The image is decoded as RGB, resized to 512x512, normalised with the
    mean/std values below, converted to a tensor by ``ToTensorV2`` and given
    a leading batch dimension.

    Args:
        image_path: Location of the image, passed to ``PIL.Image.open``.

    Returns:
        The transformed image tensor with a batch axis of size 1 prepended.
    """
    transforms = A.Compose([
        A.Resize(512, 512),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])
    # Image.open is lazy and keeps the file open; the context manager closes
    # the handle as soon as the pixels have been copied into the array.
    with Image.open(image_path) as source_image:
        image = np.array(source_image.convert("RGB"))
    image_tensor = transforms(image=image)['image'].unsqueeze(0)
    return image_tensor
|
134 |
+
|
135 |
+
# Example inference
|
136 |
+
def predict_dr_severity(image_path):
    """Classify one image and report the predicted DR severity.

    Uses the module-level ``model`` and ``device`` and the
    ``preprocess_image`` helper.

    Args:
        image_path: Path handed straight to ``preprocess_image``.

    Returns:
        dict with ``'predicted_severity'`` (label string), ``'confidence'``
        (softmax probability of the predicted class, as a float) and
        ``'all_probabilities'`` (NumPy array of the softmax probabilities for
        the single input image).
    """
    batch = preprocess_image(image_path).to(device)

    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        outputs = model(batch)

    # Convert the severity logits to probabilities and pick the top class.
    probs = outputs['severity'].softmax(dim=1)
    top_class = probs.argmax(dim=1).item()

    # Class index -> human-readable grade.
    severity_labels = dict(enumerate((
        "No DR",
        "Mild DR",
        "Moderate DR",
        "Severe DR",
        "Proliferative DR",
    )))

    return {
        'predicted_severity': severity_labels[top_class],
        'confidence': probs[0, top_class].item(),
        'all_probabilities': probs[0].cpu().numpy(),
    }
|
160 |
+
|
161 |
+
# Example usage
|
162 |
+
# result = predict_dr_severity("path/to/your/fundus_image.jpg")
|
163 |
+
# print(f"Predicted: {result['predicted_severity']} (Confidence: {result['confidence']:.3f})")
|
164 |
+
```
|
165 |
+
|
166 |
+
## Training Details
|
167 |
+
|
168 |
+
### V2 Improvements
|
169 |
+
This model (V2) was specifically designed to address the shortcomings of a baseline model (V1) that struggled with severe-stage DR detection:
|
170 |
+
|
171 |
+
- **Higher Resolution:** Increased from 224×224 to 512×512 to capture finer pathological details
|
172 |
+
- **Class Balancing:** Implemented WeightedRandomSampler to oversample rare minority classes (Severe and Proliferative DR)
|
173 |
+
- **Focal Loss:** Replaced standard Cross-Entropy with Focal Loss (γ=2.0) to focus on hard-to-classify examples
|
174 |
+
- **Focused Training:** Set auxiliary task weights to zero, dedicating full model capacity to severity classification
|
175 |
+
|
176 |
+
### Hyperparameters
|
177 |
+
- **Optimizer:** AdamW
|
178 |
+
- **Learning Rate:** 1e-4
|
179 |
+
- **Scheduler:** CosineAnnealingWarmRestarts (T_MAX=10)
|
180 |
+
- **Batch Size:** 16
|
181 |
+
- **Epochs:** 17 (Early stopping)
|
182 |
+
- **Image Size:** 512×512
|
183 |
+
|
184 |
+
## Performance
|
185 |
+
|
186 |
+
The model was evaluated on a held-out validation set of 735 images:
|
187 |
+
|
188 |
+
| Metric | Score |
|
189 |
+
|--------|-------|
|
190 |
+
| **Quadratic Weighted Kappa (QWK)** | **0.796** |
|
191 |
+
| Accuracy | 65.0% |
|
192 |
+
| F1-Score (Weighted) | 66.3% |
|
193 |
+
| F1-Score (Macro) | 53.5% |
|
194 |
+
|
195 |
+
### Key Achievement
|
196 |
+
The V2 model achieved a **+0.035 absolute improvement in QWK** (0.796 vs. 0.761 for the V1 baseline, i.e. about 3.5 points), indicating it makes "smarter" errors that are more aligned with clinical judgment, despite lower overall accuracy. This trade-off prioritizes clinically relevant performance over naive accuracy.
|
197 |
+
|
198 |
+
## Limitations
|
199 |
+
|
200 |
+
⚠️ **Important Disclaimers:**
|
201 |
+
- This model was trained on a single public dataset and may not generalize to different clinical settings, camera types, or patient demographics
|
202 |
+
- The dataset may contain inherent demographic biases
|
203 |
+
- **This is NOT a medical device** and should not be used for actual clinical diagnosis
|
204 |
+
- Always consult qualified healthcare professionals for medical decisions
|
205 |
+
|
206 |
+
## Citation
|
207 |
+
|
208 |
+
If you use this model in your research, please cite:
|
209 |
+
|
210 |
+
```bibtex
|
211 |
+
@misc{dheerentejani2025dr,
|
212 |
+
author = {Dheeren Tejani},
|
213 |
+
title = {Diabetic Retinopathy Grading Model V2},
|
214 |
+
year = {2025},
|
215 |
+
publisher = {Hugging Face},
|
216 |
+
journal = {Hugging Face Model Hub},
|
217 |
+
howpublished = {\url{https://huggingface.co/dheeren-tejani/DiabeticRetinpathyClassifier}},
|
218 |
+
}
|
219 |
+
```
|
220 |
+
|
221 |
+
## License
|
222 |
+
|
223 |
+
This model is released under the Apache 2.0 License.
|