YashikaNagpal committed
Commit 7aebb7e · verified · 1 Parent(s): d7e1949

Create README.md

Files changed (1): README.md +122 -0

# Model Card: Fine-Tuned MobileNetV2 for Skin Lesion Classification

# Model Details
**Model Name:** Fine-Tuned MobileNetV2 for Skin Lesion Classification
**Base Model:** google/mobilenet_v2_1.0_224 (pretrained on ImageNet)
**Dataset:** marmal88/skin_cancer
**Quantization:** Available as an optional FP16 version for optimized inference
**Training Device:** CUDA (GPU, 12 GB)
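
The sketch below illustrates how the base model's classifier head can be adapted to the 7 lesion classes. It is a minimal, illustrative example that assumes the torchvision MobileNetV2 implementation (the ImageNet-pretrained counterpart of google/mobilenet_v2_1.0_224); it is not the exact training code.

```python
import torch.nn as nn
from torchvision import models

NUM_CLASSES = 7  # number of lesion classes in marmal88/skin_cancer

# Load an ImageNet-pretrained MobileNetV2 and swap the classifier head
# (assumption: the torchvision implementation was used for fine-tuning).
model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
model.classifier[1] = nn.Linear(model.last_channel, NUM_CLASSES)
```

Leaving every parameter trainable corresponds to the "fine-tuning all layers" setting described under Training Details.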

# Dataset Information
```
Dataset Structure

DatasetDict({
    train: Dataset({
        features: ['image', 'image_id', 'lesion_id', 'dx', 'dx_type', 'age', 'sex', 'localization'],
        num_rows: 9577
    })
    validation: Dataset({
        features: ['image', 'image_id', 'lesion_id', 'dx', 'dx_type', 'age', 'sex', 'localization'],
        num_rows: 2492
    })
    test: Dataset({
        features: ['image', 'image_id', 'lesion_id', 'dx', 'dx_type', 'age', 'sex', 'localization'],
        num_rows: 1285
    })
})

Available Splits
Train: 9,577 examples
Validation: 2,492 examples
Test: 1,285 examples
```
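
The structure above can be reproduced by loading the dataset from the Hugging Face Hub with the datasets library; a minimal sketch:

```python
from datasets import load_dataset

# Load the dataset from the Hub; the splits match the structure shown above.
dataset = load_dataset("marmal88/skin_cancer")
print(dataset)                     # DatasetDict with train/validation/test
print(dataset["train"][0]["dx"])   # a diagnosis string, e.g. 'melanoma'
```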
# Feature Representation
- image: RGB image (originally 600x450, resized to 224x224 during preprocessing)
- image_id: Unique identifier (e.g., ISIC_0024329)
- lesion_id: Lesion identifier (e.g., HAM_0002954)
- dx: Diagnosis label (e.g., melanoma, actinic_keratoses)
- dx_type: Diagnosis method (e.g., histo for histopathology)
- age: Patient age (float, e.g., 75.0)
- sex: Patient sex (e.g., female)
- localization: Body location (e.g., lower extremity)

**Note:** Only image and dx (converted to integer labels) were used for training; the other features were dropped during preprocessing.
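
As a rough illustration of that preprocessing, the sketch below maps the dx strings to integer labels, applies the resize/normalization described under Training Details, and drops the remaining columns. It assumes the `dataset` object loaded above; the helper names (label2id, preprocess_example, train_ds) are illustrative, not the exact training code. The label ordering actually shipped with the model is the one stored in labels.json (see the inference example).

```python
from torchvision import transforms

# Map diagnosis strings to integer labels (sorted for a stable ordering).
class_names = sorted(set(dataset["train"]["dx"]))
label2id = {name: i for i, name in enumerate(class_names)}

# Same resize/normalization as described under Training Details.
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def preprocess_example(example):
    # Keep only the image tensor and the integer label; all other columns are dropped.
    return {
        "pixel_values": image_transform(example["image"].convert("RGB")),
        "label": label2id[example["dx"]],
    }

train_ds = dataset["train"].map(
    preprocess_example,
    remove_columns=dataset["train"].column_names,
)
```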

# Training Details
- Number of Classes: 7
- Class Names: actinic_keratoses, basal_cell_carcinoma, benign_keratosis-like_lesions, dermatofibroma, melanocytic_Nevi, melanoma, vascular_lesions
- Training Process: Fine-tuned for 5 epochs (reduced from the initially planned 10); a minimal training sketch follows this list
- Learning Rate: 0.001 (Adam optimizer, fine-tuning all layers)
- Batch Size: 32 (suitable for a 12 GB GPU)
- Preprocessing: Images resized to 224x224, normalized with mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] (ImageNet stats)
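
A minimal training-loop sketch using the hyperparameters above (Adam at lr 0.001, batch size 32, 5 epochs, cross-entropy loss). `train_dataset` is a placeholder for a preprocessed training split whose DataLoader yields (image batch, label batch) tensors; this illustrates the setup rather than reproducing the exact script that produced the checkpoint.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)  # MobileNetV2 with the 7-class head from the sketch above

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # all layers trainable

# Placeholder: a DataLoader over the preprocessed training split.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

for epoch in range(5):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * images.size(0)
    print(f"Epoch {epoch + 1}: train loss {running_loss / len(train_loader.dataset):.4f}")
```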

# Performance Metrics
- Epochs: 5
- Training Loss: [To be filled after training output]
- Validation Loss: [To be filled after training output]
- Accuracy: [To be filled after training output]
- F1 Score: Not computed (can be added with additional evaluation; see the sketch below)

**Note:** These values depend on the specific training run; fill them in from the training log (loss/accuracy per epoch).
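
Since F1 was not computed during training, the following sketch shows one way to obtain accuracy and macro F1 on the test split with scikit-learn. `test_loader` is assumed to be a DataLoader over the preprocessed test split, analogous to the training loader above.

```python
import torch
from sklearn.metrics import accuracy_score, f1_score

model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for images, labels in test_loader:  # placeholder: DataLoader over the test split
        outputs = model(images.to(device))
        all_preds.extend(outputs.argmax(dim=1).cpu().tolist())
        all_labels.extend(labels.tolist())

print(f"Accuracy: {accuracy_score(all_labels, all_preds):.4f}")
print(f"Macro F1: {f1_score(all_labels, all_preds, average='macro'):.4f}")
```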

# Inference Example
```python
import torch
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
import json

# Preprocessing (matches training)
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load model and labels
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = "skin_cancer_model_fp16/mobilenetv2_skin_cancer_fp16.pt"
# The checkpoint stores the full pickled model; on recent PyTorch versions
# torch.load may require weights_only=False to load it.
model = torch.load(model_path, map_location=device)
model = model.to(device)
model.eval()

with open("skin_cancer_model_fp16/labels.json", 'r') as f:
    label_mapping = json.load(f)
class_names = list(label_mapping.keys())

# Inference function
def predict_image(image_path, model, preprocess, device, class_names):
    image = Image.open(image_path).convert('RGB')
    image_tensor = preprocess(image).unsqueeze(0).half()  # FP16 input to match the FP16 model
    image_tensor = image_tensor.to(device)

    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = F.softmax(outputs, dim=1)
        confidence, predicted = torch.max(probabilities, 1)
        predicted_class = class_names[predicted.item()]
        confidence_score = confidence.item() * 100

    return predicted_class, confidence_score

# Example usage
if __name__ == "__main__":
    image_path = "C:/path/to/your/image.jpg"  # Replace with your image path
    predicted_class, confidence = predict_image(image_path, model, preprocess, device, class_names)
    print(f"Predicted Class: {predicted_class}")
    print(f"Confidence: {confidence:.2f}%")
```
# Quantization & Optimization
**Quantization:** Optional FP16 version created with PyTorch's .half() for faster inference and a reduced memory footprint (~50% size reduction).
**Optimization:** Suitable for deployment on GPU-enabled devices (e.g., CUDA devices with 12 GB VRAM).
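
For reference, a minimal sketch of how such an FP16 variant can be produced and saved with .half(). The file paths mirror those used in the inference example, and the labels.json format is an assumption consistent with what that example reads.

```python
import json
import os
import torch

# Cast the trained FP32 model to half precision and save the full model object,
# along with the class-name to index mapping the inference example reads.
os.makedirs("skin_cancer_model_fp16", exist_ok=True)
model_fp16 = model.half()
torch.save(model_fp16, "skin_cancer_model_fp16/mobilenetv2_skin_cancer_fp16.pt")

with open("skin_cancer_model_fp16/labels.json", "w") as f:
    json.dump(label2id, f)  # assumed format: keys are class names, in label order
```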

# Usage
- Input: RGB images (any size, resized to 224x224 during preprocessing)
- Output: Predicted skin lesion class (one of 7) with a confidence score

# Limitations
- Generalization: Trained on the marmal88/skin_cancer dataset, which may not fully represent all real-world skin lesions (e.g., varying lighting, angles, or skin types).
- Dataset Bias: Performance may vary with dataset diversity; metadata such as age, sex, and localization was not used in training.
- Accuracy: Training was limited to 5 epochs; further training or larger datasets might improve results.

# Future Improvements
- Data Augmentation: Add rotation, flipping, or color jittering to enhance robustness (see the sketch below).
- Larger Dataset: Incorporate additional skin cancer datasets (e.g., ISIC Archive) for better coverage.
- Model Tuning: Experiment with freezing feature layers or adjusting the learning rate for better convergence.
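
As a starting point for the augmentation idea above, a minimal torchvision pipeline (illustrative parameter values, intended for the training split only):

```python
from torchvision import transforms

# Augmented training pipeline: rotation, flips, and color jitter applied
# before the usual resize/normalize steps described under Training Details.
augmented_train_transform = transforms.Compose([
    transforms.RandomRotation(degrees=20),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
```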