Amarthya7 committed on
Commit
05e3595
·
verified ·
1 Parent(s): d117490

Update mediSync/models/image_analyzer.py

Browse files
Files changed (1) hide show
  1. mediSync/models/image_analyzer.py +194 -194
mediSync/models/image_analyzer.py CHANGED
@@ -1,194 +1,194 @@
1
- import logging
2
- import os
3
-
4
- import torch
5
- from PIL import Image
6
- from transformers import AutoFeatureExtractor, AutoModelForImageClassification
7
-
8
-
9
class XRayImageAnalyzer:
    """
    A class for analyzing medical X-ray images using pre-trained models from Hugging Face.

    The default checkpoint name refers to a DeiT (Data-efficient image Transformers)
    chest X-ray model; any checkpoint loadable by ``AutoModelForImageClassification``
    can be supplied instead. Class names come from the model config (``id2label``).
    A label equal to "normal" or "no finding" (case-insensitive) is treated as the
    "nothing abnormal" class.
    """

    def __init__(
        self, model_name="facebook/deit-base-patch16-224-medical-cxr", device=None
    ):
        """
        Initialize the X-ray image analyzer with a specific pre-trained model.

        Args:
            model_name (str): The Hugging Face model name to use
            device (str, optional): Device to run the model on ('cuda' or 'cpu').
                When omitted, CUDA is used if available, otherwise the CPU.

        Raises:
            Exception: Any error raised while loading the feature extractor or
                the model is logged and re-raised unchanged.
        """
        self.logger = logging.getLogger(__name__)

        # Determine device (CPU or GPU)
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        self.logger.info(f"Using device: {self.device}")

        # Load model and feature extractor
        try:
            self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
            self.model = AutoModelForImageClassification.from_pretrained(model_name)
            self.model.to(self.device)
            self.model.eval()  # Set to evaluation mode
            self.logger.info(f"Successfully loaded model: {model_name}")

            # Class index -> label name, as stored in the model config
            self.labels = self.model.config.id2label

        except Exception as e:
            self.logger.error(f"Failed to load model: {e}")
            raise

    def preprocess_image(self, image_path):
        """
        Preprocess an X-ray image for model input.

        Args:
            image_path (str or PIL.Image): Path to image or PIL Image object

        Returns:
            tuple: ``(inputs, image)``. ``inputs`` is a dict of tensors already
            moved to ``self.device``; ``image`` is the RGB PIL image that was
            passed to the feature extractor.

        Raises:
            FileNotFoundError: If ``image_path`` is a string and no such path exists.
        """
        try:
            if isinstance(image_path, str):
                if not os.path.exists(image_path):
                    raise FileNotFoundError(f"Image file not found: {image_path}")
                # convert() returns an independent, fully loaded image, so the
                # underlying file can be closed as soon as the conversion is done.
                with Image.open(image_path) as opened:
                    image = opened.convert("RGB")
            else:
                # Assume it's already a PIL Image
                image = image_path.convert("RGB")

            # Apply feature extraction
            inputs = self.feature_extractor(images=image, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            return inputs, image

        except Exception as e:
            self.logger.error(f"Error in preprocessing image: {e}")
            raise

    def analyze(self, image_path, threshold=0.5):
        """
        Analyze an X-ray image and detect abnormalities.

        Args:
            image_path (str or PIL.Image): Path to the X-ray image or PIL Image object
            threshold (float): The image is reported as normal only when the
                "normal"/"no finding" class probability is strictly above this value

        Returns:
            dict: Analysis results including:
                - predictions: List of (label, probability) tuples, most probable first
                - primary_finding: The most likely abnormality, or
                  "No abnormalities detected" when the image is reported as normal
                - has_abnormality: Boolean indicating if abnormalities were detected
                - confidence: Probability associated with the primary finding
        """
        try:
            # Preprocess the image (the returned PIL image is not needed here)
            inputs, _ = self.preprocess_image(image_path)

            # Run inference
            with torch.no_grad():
                outputs = self.model(**inputs)

            # Softmax over the class axis, first (only) item of the batch
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
            probabilities = probabilities.cpu().tolist()

            # (label, probability) pairs, sorted by probability (descending)
            predictions = [
                (self.labels[i], float(p)) for i, p in enumerate(probabilities)
            ]
            predictions.sort(key=lambda item: item[1], reverse=True)

            normal_names = ("normal", "no finding")
            normal = [item for item in predictions if item[0].lower() in normal_names]
            abnormal = [
                item for item in predictions if item[0].lower() not in normal_names
            ]

            if normal and normal[0][1] > threshold:
                has_abnormality = False
                primary_finding = "No abnormalities detected"
                confidence = normal[0][1]
            else:
                has_abnormality = True
                # Report the strongest *abnormal* class. The top-ranked class may
                # be the normal one (just below the threshold), which must not be
                # presented as an abnormality. Fall back to the top prediction
                # only if the model exposes no abnormal class at all.
                primary_finding, confidence = abnormal[0] if abnormal else predictions[0]

            return {
                "predictions": predictions,
                "primary_finding": primary_finding,
                "has_abnormality": has_abnormality,
                "confidence": confidence,
            }

        except Exception as e:
            self.logger.error(f"Error analyzing image: {e}")
            raise

    def get_explanation(self, results):
        """
        Generate a human-readable explanation of the analysis results.

        Args:
            results (dict): The results returned by the analyze method

        Returns:
            str: A text explanation of the findings. For abnormal results, up to
            three other predictions with probability above 5% are listed.
        """
        if not results["has_abnormality"]:
            return (
                f"The X-ray appears normal with {results['confidence']:.1%} confidence."
            )

        explanation = (
            f"The primary finding is {results['primary_finding']} "
            f"with {results['confidence']:.1%} confidence."
        )

        # Skip the primary finding by label rather than by position: it is not
        # necessarily the first entry of the ranked predictions.
        others = [
            item for item in results["predictions"]
            if item[0] != results["primary_finding"]
        ]
        listed = "".join(
            f"- {label}: {prob:.1%}\n"
            for label, prob in others[:3]
            if prob > 0.05  # Only include if probability > 5%
        )

        # Emit the section header only when there is something to list
        if listed:
            explanation += "\n\nOther potential findings include:\n" + listed

        return explanation
172
-
173
-
174
# Example usage
if __name__ == "__main__":
    # Set up logging
    logging.basicConfig(level=logging.INFO)

    # Test on a sample image if available
    analyzer = XRayImageAnalyzer()

    # List the sample directory once. isdir() (rather than exists()) keeps a
    # stray regular file at that path from making listdir() raise, and sorting
    # makes the chosen sample deterministic.
    sample_dir = "../data/sample"
    sample_names = sorted(os.listdir(sample_dir)) if os.path.isdir(sample_dir) else []

    if sample_names:
        sample_image = os.path.join(sample_dir, sample_names[0])
        print(f"Analyzing sample image: {sample_image}")

        results = analyzer.analyze(sample_image)
        explanation = analyzer.get_explanation(results)

        print("\nAnalysis Results:")
        print(explanation)
    else:
        print("No sample images found in ../data/sample directory")
 
1
+ import logging
2
+ import os
3
+
4
+ import torch
5
+ from PIL import Image
6
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification
7
+
8
+
9
class XRayImageAnalyzer:
    """
    Analyze chest X-ray images with a pre-trained Hugging Face image classifier.

    The default checkpoint is ``codewithdark/vit-chest-xray``; any checkpoint that
    ``AutoModelForImageClassification`` can load may be passed instead. Class
    names are taken from the model config (``id2label``). A label equal to
    "normal" or "no finding" (case-insensitive) is treated as the
    "nothing abnormal" class.
    """

    # Lower-cased label names that mean "nothing abnormal was found".
    _NORMAL_LABELS = frozenset({"normal", "no finding"})
    # get_explanation lists at most this many secondary findings ...
    _MAX_OTHER_FINDINGS = 3
    # ... and only those whose probability is strictly above this value.
    _MIN_OTHER_PROBABILITY = 0.05

    def __init__(
        self, model_name="codewithdark/vit-chest-xray", device=None
    ):
        """
        Initialize the X-ray image analyzer with a specific pre-trained model.

        Args:
            model_name (str): The Hugging Face model name to use
            device (str, optional): Device to run the model on ('cuda' or 'cpu').
                When omitted, CUDA is used if available, otherwise the CPU.

        Raises:
            Exception: Any error raised while loading the feature extractor or
                the model is logged and re-raised unchanged.
        """
        self.logger = logging.getLogger(__name__)

        # Determine device (CPU or GPU)
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        self.logger.info(f"Using device: {self.device}")

        # Load model and feature extractor
        try:
            self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
            self.model = AutoModelForImageClassification.from_pretrained(model_name)
            self.model.to(self.device)
            self.model.eval()  # Set to evaluation mode
            self.logger.info(f"Successfully loaded model: {model_name}")

            # Class index -> label name, as stored in the model config
            self.labels = self.model.config.id2label

        except Exception as e:
            self.logger.error(f"Failed to load model: {e}")
            raise

    @classmethod
    def _is_normal_label(cls, label):
        """Return True when ``label`` names the "nothing abnormal" class."""
        return label.lower() in cls._NORMAL_LABELS

    def _rank_predictions(self, probabilities):
        """Pair each class probability with its label, most probable first."""
        ranked = [(self.labels[i], float(p)) for i, p in enumerate(probabilities)]
        ranked.sort(key=lambda pair: pair[1], reverse=True)
        return ranked

    def preprocess_image(self, image_path):
        """
        Load an X-ray image and turn it into model-ready tensors.

        Args:
            image_path (str or PIL.Image): Path to image or PIL Image object

        Returns:
            tuple: ``(inputs, image)``. ``inputs`` is a dict of tensors already
            moved to ``self.device``; ``image`` is the RGB PIL image that was
            passed to the feature extractor.

        Raises:
            FileNotFoundError: If ``image_path`` is a string and no such path exists.
        """
        try:
            if isinstance(image_path, str):
                if not os.path.exists(image_path):
                    raise FileNotFoundError(f"Image file not found: {image_path}")
                # convert() hands back an independent, fully loaded image, so
                # the file handle can be released right after the conversion.
                with Image.open(image_path) as source:
                    image = source.convert("RGB")
            else:
                # Anything that is not a path is assumed to be a PIL Image
                image = image_path.convert("RGB")

            inputs = self.feature_extractor(images=image, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            return inputs, image

        except Exception as e:
            self.logger.error(f"Error in preprocessing image: {e}")
            raise

    def analyze(self, image_path, threshold=0.5):
        """
        Analyze an X-ray image and detect abnormalities.

        Args:
            image_path (str or PIL.Image): Path to the X-ray image or PIL Image object
            threshold (float): The image is reported as normal only when the
                "normal"/"no finding" class probability is strictly above this value

        Returns:
            dict: Analysis results including:
                - predictions: List of (label, probability) tuples, most probable first
                - primary_finding: The most likely abnormality, or
                  "No abnormalities detected" when the image is reported as normal
                - has_abnormality: Boolean indicating if abnormalities were detected
                - confidence: Probability associated with the primary finding
        """
        try:
            # Only the tensors are needed here; the PIL image is discarded
            inputs, _ = self.preprocess_image(image_path)

            with torch.no_grad():
                outputs = self.model(**inputs)

            # Softmax over the class axis, first (only) item of the batch
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
            predictions = self._rank_predictions(probabilities.cpu().tolist())

            normal = next(
                (pair for pair in predictions if self._is_normal_label(pair[0])), None
            )

            if normal is not None and normal[1] > threshold:
                has_abnormality = False
                primary_finding = "No abnormalities detected"
                confidence = normal[1]
            else:
                has_abnormality = True
                # Report the strongest *abnormal* class. The top-ranked class may
                # be the normal one (just below the threshold), which must not be
                # presented as an abnormality. Fall back to the top prediction
                # only if the model exposes no abnormal class at all.
                primary_finding, confidence = next(
                    (pair for pair in predictions if not self._is_normal_label(pair[0])),
                    predictions[0],
                )

            return {
                "predictions": predictions,
                "primary_finding": primary_finding,
                "has_abnormality": has_abnormality,
                "confidence": confidence,
            }

        except Exception as e:
            self.logger.error(f"Error analyzing image: {e}")
            raise

    def get_explanation(self, results):
        """
        Generate a human-readable explanation of the analysis results.

        Args:
            results (dict): The results returned by the analyze method

        Returns:
            str: A text explanation of the findings. For abnormal results, up to
            three other predictions with probability above 5% are listed.
        """
        confidence = results["confidence"]
        if not results["has_abnormality"]:
            return f"The X-ray appears normal with {confidence:.1%} confidence."

        primary = results["primary_finding"]
        explanation = f"The primary finding is {primary} with {confidence:.1%} confidence."

        # Exclude the primary finding by label, not by position: it is not
        # necessarily the first entry of the ranked predictions.
        others = [pair for pair in results["predictions"] if pair[0] != primary]
        bullet_lines = [
            f"- {label}: {prob:.1%}\n"
            for label, prob in others[: self._MAX_OTHER_FINDINGS]
            if prob > self._MIN_OTHER_PROBABILITY
        ]

        # Emit the section header only when there is something to list
        if bullet_lines:
            explanation += "\n\nOther potential findings include:\n" + "".join(bullet_lines)

        return explanation
172
+
173
+
174
# Example usage
if __name__ == "__main__":
    # Set up logging
    logging.basicConfig(level=logging.INFO)

    # Collect candidate sample files with a single directory listing. Only
    # regular files are kept, and sorting makes the chosen sample deterministic.
    sample_dir = "../data/sample"
    sample_files = []
    if os.path.isdir(sample_dir):
        sample_files = sorted(
            name
            for name in os.listdir(sample_dir)
            if os.path.isfile(os.path.join(sample_dir, name))
        )

    if sample_files:
        sample_image = os.path.join(sample_dir, sample_files[0])
        print(f"Analyzing sample image: {sample_image}")

        # Load the model only once there is actually something to analyze
        analyzer = XRayImageAnalyzer()
        results = analyzer.analyze(sample_image)
        explanation = analyzer.get_explanation(results)

        print("\nAnalysis Results:")
        print(explanation)
    else:
        print("No sample images found in ../data/sample directory")