dheeren-tejani committed
Commit 4f95b95 · 1 Parent(s): f93191d

Initial commit

Files changed (1): README.md +183 -9
README.md CHANGED
@@ -1,4 +1,3 @@
- <<<<<<< HEAD
  ---
  license: apache-2.0
  language: en
@@ -13,9 +12,9 @@ tags:
  datasets:
  - aptos2019-blindness-detection
  widget:
- - src: model/gradcam_visualizations/gradcam_sample_003.png
+ - src: gradcam_visualizations/gradcam_sample_003.png
    example_title: No DR Example
- - src: model/gradcam_visualizations/gradcam_sample_007.png
+ - src: gradcam_visualizations/gradcam_sample_007.png
    example_title: Severe DR Example
  ---
@@ -41,9 +40,184 @@ The model can be easily loaded from Hugging Face Hub for inference.
- pip install torch torchvision timm albumentations huggingface-hub numpy<2.0 pillow opencv-python
- =======
- ---
- license: apache-2.0
- ---
- >>>>>>> 96df5800dd01e6edcea4c1417179b056d6e3e22d

```bash
# Install required libraries
pip install torch torchvision timm albumentations huggingface-hub numpy pillow opencv-python
```

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from huggingface_hub import hf_hub_download

# Define the model architecture
class MultiTaskDRModel(nn.Module):
    def __init__(self, model_name='efficientnet_b3', num_classes=5,
                 num_lesion_types=5, num_regions=5, pretrained=False):
        super(MultiTaskDRModel, self).__init__()
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        self.feature_dim = self.backbone.num_features

        # Channel attention applied to the pooled backbone features
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(self.feature_dim, self.feature_dim // 8), nn.ReLU(inplace=True),
            nn.Linear(self.feature_dim // 8, self.feature_dim), nn.Sigmoid()
        )

        self.feature_norm = nn.BatchNorm1d(self.feature_dim)
        self.dropout = nn.Dropout(0.4)

        # Main head: 5-class DR severity
        self.severity_classifier = nn.Sequential(
            nn.Linear(self.feature_dim, self.feature_dim // 2), nn.ReLU(inplace=True),
            nn.Dropout(0.2), nn.Linear(self.feature_dim // 2, num_classes)
        )

        # Auxiliary heads: lesion types and retinal regions
        self.lesion_detector = nn.Sequential(
            nn.Linear(self.feature_dim, self.feature_dim // 4), nn.ReLU(inplace=True),
            nn.Dropout(0.2), nn.Linear(self.feature_dim // 4, num_lesion_types)
        )

        self.region_predictor = nn.Sequential(
            nn.Linear(self.feature_dim, self.feature_dim // 4), nn.ReLU(inplace=True),
            nn.Dropout(0.2), nn.Linear(self.feature_dim // 4, num_regions)
        )

    def forward(self, x):
        features = self.backbone.forward_features(x)
        pooled_features = F.adaptive_avg_pool2d(features, 1).flatten(1)
        attention_weights = self.attention(pooled_features.unsqueeze(-1).unsqueeze(-1))
        features = pooled_features * attention_weights
        features = self.feature_norm(features)
        features = self.dropout(features)

        severity_logits = self.severity_classifier(features)
        lesion_logits = self.lesion_detector(features)
        region_logits = self.region_predictor(features)

        return {
            'severity': severity_logits,
            'lesions': lesion_logits,
            'regions': region_logits,
            'features': features
        }

# Load the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MultiTaskDRModel()

# Download and load the checkpoint
model_path = hf_hub_download(
    repo_id="dheeren-tejani/DiabeticRetinpathyClassifier",
    filename="best_model_v2.pth"
)
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()

print("Model loaded successfully!")

# Preprocessing function: resize to 512x512 and apply ImageNet normalization
def preprocess_image(image_path):
    transforms = A.Compose([
        A.Resize(512, 512),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])
    image = np.array(Image.open(image_path).convert("RGB"))
    image_tensor = transforms(image=image)['image'].unsqueeze(0)
    return image_tensor

# Example inference
def predict_dr_severity(image_path):
    image_tensor = preprocess_image(image_path).to(device)

    with torch.no_grad():
        outputs = model(image_tensor)

    # Get severity prediction
    severity_probs = torch.softmax(outputs['severity'], dim=1)
    predicted_class = torch.argmax(severity_probs, dim=1).item()
    confidence = severity_probs[0, predicted_class].item()

    severity_labels = {
        0: "No DR",
        1: "Mild DR",
        2: "Moderate DR",
        3: "Severe DR",
        4: "Proliferative DR"
    }

    return {
        'predicted_severity': severity_labels[predicted_class],
        'confidence': confidence,
        'all_probabilities': severity_probs[0].cpu().numpy()
    }

# Example usage
# result = predict_dr_severity("path/to/your/fundus_image.jpg")
# print(f"Predicted: {result['predicted_severity']} (Confidence: {result['confidence']:.3f})")
```
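To read out the full probability distribution rather than just the top class, you can build on `predict_dr_severity` above. The snippet below is only an illustrative sketch; it assumes the loading and inference code above has already been run, and the label order mirrors the 0–4 mapping inside that function:

```python
# Assumes the model-loading and predict_dr_severity code above has been executed
severity_names = ["No DR", "Mild DR", "Moderate DR", "Severe DR", "Proliferative DR"]

result = predict_dr_severity("path/to/your/fundus_image.jpg")
for name, prob in zip(severity_names, result['all_probabilities']):
    print(f"{name:>20s}: {prob:.3f}")
```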

## Training Details

### V2 Improvements
This model (V2) was specifically designed to address the shortcomings of a baseline model (V1) that struggled with severe-stage DR detection:

- **Higher Resolution:** Increased from 224×224 to 512×512 to capture finer pathological details
- **Class Balancing:** Implemented a WeightedRandomSampler to oversample the rare minority classes (Severe and Proliferative DR)
- **Focal Loss:** Replaced standard cross-entropy with Focal Loss (γ = 2.0) to focus training on hard-to-classify examples (see the sketch after this list)
- **Focused Training:** Set the auxiliary task weights to zero, dedicating the model's full capacity to severity classification

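As a rough illustration of the class balancing and Focal Loss described above, here is a minimal sketch. It is not the original training script: the `FocalLoss` class, the `make_balanced_loader` helper, and the inverse-frequency weighting are assumptions made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, WeightedRandomSampler

class FocalLoss(nn.Module):
    """Cross-entropy scaled by (1 - p_t)^gamma so easy examples contribute less."""
    def __init__(self, gamma=2.0):
        super().__init__()
        self.gamma = gamma

    def forward(self, logits, targets):
        ce = F.cross_entropy(logits, targets, reduction='none')
        pt = torch.exp(-ce)  # probability the model assigned to the true class
        return ((1.0 - pt) ** self.gamma * ce).mean()

def make_balanced_loader(dataset, labels, batch_size=16, num_classes=5):
    """Oversample minority classes by weighting each sample with 1 / class frequency (an assumed scheme)."""
    labels = torch.as_tensor(labels, dtype=torch.long)
    class_counts = torch.bincount(labels, minlength=num_classes).float().clamp(min=1)
    sample_weights = 1.0 / class_counts[labels]
    sampler = WeightedRandomSampler(sample_weights, num_samples=len(labels), replacement=True)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)
```

With a setup along these lines, Severe and Proliferative DR images are drawn far more often per epoch, and the loss further down-weights examples the model already classifies confidently.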
### Hyperparameters
- **Optimizer:** AdamW
- **Learning Rate:** 1e-4
- **Scheduler:** CosineAnnealingWarmRestarts (T_MAX=10) (see the sketch below)
- **Batch Size:** 16
- **Epochs:** 17 (early stopping)
- **Image Size:** 512×512

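A minimal sketch of how these settings might be wired up in PyTorch, not the original training loop. `MultiTaskDRModel` comes from the usage example above, and the card's "T_MAX=10" is interpreted here as the `T_0` argument of `CosineAnnealingWarmRestarts` (an assumption):

```python
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

model = MultiTaskDRModel()  # defined in the usage example above
optimizer = AdamW(model.parameters(), lr=1e-4)
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10)  # restart period of 10 epochs

num_epochs = 17  # training stopped early at epoch 17 per the card
for epoch in range(num_epochs):
    # ... one training pass over the balanced loader goes here ...
    scheduler.step()
```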
## Performance

The model was evaluated on a held-out validation set of 735 images (a sketch of how QWK is computed follows the table):

| Metric | Score |
|--------|-------|
| **Quadratic Weighted Kappa (QWK)** | **0.796** |
| Accuracy | 65.0% |
| F1-Score (Weighted) | 66.3% |
| F1-Score (Macro) | 53.5% |

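Quadratic Weighted Kappa penalizes a prediction by the squared distance between the predicted and true grade, so confusing Severe (3) with Proliferative (4) costs far less than calling it No DR (0). A small sketch using scikit-learn (an extra dependency not listed in the install command above; the toy labels are made up):

```python
from sklearn.metrics import accuracy_score, cohen_kappa_score

y_true     = [0, 2, 3, 4, 1, 3]
y_adjacent = [0, 2, 4, 3, 1, 2]  # errors land on neighbouring grades
y_distant  = [0, 2, 0, 0, 1, 0]  # same number of errors, but far from the truth

for name, y_pred in [("adjacent errors", y_adjacent), ("distant errors", y_distant)]:
    qwk = cohen_kappa_score(y_true, y_pred, weights='quadratic')
    acc = accuracy_score(y_true, y_pred)
    print(f"{name}: accuracy={acc:.2f}, QWK={qwk:.2f}")
```

Both toy predictions have the same accuracy, but the one whose errors stay close to the true grade scores a much higher QWK, which is why the metric is the headline number here.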
### Key Achievement
The V2 model achieved a **+0.035 absolute improvement in QWK** over the V1 baseline (0.761 → 0.796), indicating that it makes "smarter" errors that are more aligned with clinical judgment, despite its lower overall accuracy. This trade-off prioritizes clinically relevant performance over naive accuracy.

## Limitations

⚠️ **Important Disclaimers:**
- This model was trained on a single public dataset and may not generalize to different clinical settings, camera types, or patient demographics
- The dataset may contain inherent demographic biases
- **This is NOT a medical device** and should not be used for actual clinical diagnosis
- Always consult qualified healthcare professionals for medical decisions

## Citation

If you use this model in your research, please cite:

```bibtex
@misc{dheerentejani2025dr,
  author       = {Dheeren Tejani},
  title        = {Diabetic Retinopathy Grading Model V2},
  year         = {2025},
  publisher    = {Hugging Face},
  journal      = {Hugging Face Model Hub},
  howpublished = {\url{https://huggingface.co/dheeren-tejani/DiabeticRetinpathyClassifier}},
}
```

## License

This model is released under the Apache 2.0 License.