Amarthya7 committed on
Commit 86a74e6 · verified · 1 Parent(s): 681725a

Upload 21 files
mediSync/__init__.py ADDED
@@ -0,0 +1,17 @@
+"""
+MediSync: Multi-Modal Medical Analysis System
+==============================================
+
+A healthcare solution that combines X-ray image analysis with patient report text processing
+to provide comprehensive medical insights.
+
+This package contains the following modules:
+- models: Image and text analysis models, along with multimodal fusion
+- utils: Utility functions for preprocessing and visualization
+- app: Main application with Gradio interface
+
+Author: AI Development Team
+License: MIT
+"""
+
+__version__ = "0.1.0"
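
Note: for orientation, a minimal usage sketch of the package described above (illustrative, not part of this commit; it assumes the repository root is on sys.path and that the pre-trained models download successfully):

    from mediSync.models import MultimodalFusion

    fusion = MultimodalFusion()  # loads the image and text analyzers
    results = fusion.analyze("chest_xray.png", "FINDINGS: ...")  # hypothetical inputs
    print(results["primary_finding"], results["agreement_score"])
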
mediSync/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (712 Bytes).
 
mediSync/__pycache__/app.cpython-311.pyc ADDED
Binary file (25.9 kB).
 
mediSync/app.py ADDED
@@ -0,0 +1,560 @@
+import logging
+import os
+import sys
+import tempfile
+from pathlib import Path
+
+import gradio as gr
+import matplotlib.pyplot as plt
+from PIL import Image
+
+# Add parent directory to path
+parent_dir = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(parent_dir)
+
+# Import our modules
+from models.multimodal_fusion import MultimodalFusion
+from utils.preprocessing import enhance_xray_image, normalize_report_text
+from utils.visualization import (
+    plot_image_prediction,
+    plot_multimodal_results,
+    plot_report_entities,
+)
+
+# Set up logging
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+    handlers=[logging.StreamHandler(), logging.FileHandler("mediSync.log")],
+)
+logger = logging.getLogger(__name__)
+
+# Create temporary directory for sample data if it doesn't exist
+os.makedirs(os.path.join(parent_dir, "data", "sample"), exist_ok=True)
+
+
+class MediSyncApp:
+    """
+    Main application class for the MediSync multi-modal medical analysis system.
+    """
+
+    def __init__(self):
+        """Initialize the application and load models."""
+        self.logger = logging.getLogger(__name__)
+        self.logger.info("Initializing MediSync application")
+
+        # Initialize models with None for lazy loading
+        self.fusion_model = None
+        self.image_model = None
+        self.text_model = None
+
+    def load_models(self):
+        """
+        Load models if not already loaded.
+
+        Returns:
+            bool: True if models loaded successfully, False otherwise
+        """
+        try:
+            if self.fusion_model is None:
+                self.logger.info("Loading models...")
+                self.fusion_model = MultimodalFusion()
+                self.image_model = self.fusion_model.image_analyzer
+                self.text_model = self.fusion_model.text_analyzer
+                self.logger.info("Models loaded successfully")
+            return True
+
+        except Exception as e:
+            self.logger.error(f"Error loading models: {e}")
+            return False
+
+    def analyze_image(self, image):
+        """
+        Analyze a medical image.
+
+        Args:
+            image: Image file uploaded through Gradio
+
+        Returns:
+            tuple: (image, image_results_html, plot_as_html)
+        """
+        try:
+            # Ensure models are loaded
+            if not self.load_models() or self.image_model is None:
+                return image, "Error: Models not loaded properly.", None
+
+            # Save uploaded image to a temporary file
+            temp_dir = tempfile.mkdtemp()
+            temp_path = os.path.join(temp_dir, "upload.png")
+
+            if isinstance(image, str):
+                # Copy the file if it's a path
+                from shutil import copyfile
+
+                copyfile(image, temp_path)
+            else:
+                # Save if it's a Gradio UploadButton image
+                image.save(temp_path)
+
+            # Run image analysis
+            self.logger.info(f"Analyzing image: {temp_path}")
+            results = self.image_model.analyze(temp_path)
+
+            # Create visualization
+            fig = plot_image_prediction(
+                image,
+                results.get("predictions", []),
+                f"Primary Finding: {results.get('primary_finding', 'Unknown')}",
+            )
+
+            # Convert to HTML for display
+            plot_html = self.fig_to_html(fig)
+
+            # Format results as HTML
+            html_result = f"""
+            <h2>X-ray Analysis Results</h2>
+            <p><strong>Primary Finding:</strong> {results.get("primary_finding", "Unknown")}</p>
+            <p><strong>Confidence:</strong> {results.get("confidence", 0):.1%}</p>
+            <p><strong>Abnormality Detected:</strong> {"Yes" if results.get("has_abnormality", False) else "No"}</p>
+
+            <h3>Top Predictions:</h3>
+            <ul>
+            """
+
+            # Add top 5 predictions
+            for label, prob in results.get("predictions", [])[:5]:
+                html_result += f"<li>{label}: {prob:.1%}</li>"
+
+            html_result += "</ul>"
+
+            # Add explanation
+            explanation = self.image_model.get_explanation(results)
+            html_result += f"<h3>Analysis Explanation:</h3><p>{explanation}</p>"
+
+            return image, html_result, plot_html
+
+        except Exception as e:
+            self.logger.error(f"Error in image analysis: {e}")
+            return image, f"Error analyzing image: {str(e)}", None
+
+    def analyze_text(self, text):
+        """
+        Analyze a medical report text.
+
+        Args:
+            text: Report text input through Gradio
+
+        Returns:
+            tuple: (text, text_results_html, entities_plot_html)
+        """
+        try:
+            # Ensure models are loaded
+            if not self.load_models() or self.text_model is None:
+                return text, "Error: Models not loaded properly.", None
+
+            # Check for empty text
+            if not text or len(text.strip()) < 10:
+                return (
+                    text,
+                    "Error: Please enter a valid medical report text (at least 10 characters).",
+                    None,
+                )
+
+            # Normalize text
+            normalized_text = normalize_report_text(text)
+
+            # Run text analysis
+            self.logger.info("Analyzing medical report text")
+            results = self.text_model.analyze(normalized_text)
+
+            # Get entities and create visualization
+            entities = results.get("entities", {})
+            fig = plot_report_entities(normalized_text, entities)
+
+            # Convert to HTML for display
+            entities_plot_html = self.fig_to_html(fig)
+
+            # Format results as HTML
+            html_result = f"""
+            <h2>Medical Report Analysis Results</h2>
+            <p><strong>Severity Level:</strong> {results.get("severity", {}).get("level", "Unknown")}</p>
+            <p><strong>Severity Score:</strong> {results.get("severity", {}).get("score", 0)}/4</p>
+            <p><strong>Confidence:</strong> {results.get("severity", {}).get("confidence", 0):.1%}</p>
+
+            <h3>Key Findings:</h3>
+            <ul>
+            """
+
+            # Add findings
+            findings = results.get("findings", [])
+            if findings:
+                for finding in findings:
+                    html_result += f"<li>{finding}</li>"
+            else:
+                html_result += "<li>No specific findings detailed.</li>"
+
+            html_result += "</ul>"
+
+            # Add entities
+            html_result += "<h3>Extracted Medical Entities:</h3>"
+
+            for category, items in entities.items():
+                if items:
+                    html_result += f"<p><strong>{category.capitalize()}:</strong> {', '.join(items)}</p>"
+
+            # Add follow-up recommendations
+            html_result += "<h3>Follow-up Recommendations:</h3><ul>"
+            followups = results.get("followup_recommendations", [])
+
+            if followups:
+                for rec in followups:
+                    html_result += f"<li>{rec}</li>"
+            else:
+                html_result += "<li>No specific follow-up recommendations.</li>"
+
+            html_result += "</ul>"
+
+            return text, html_result, entities_plot_html
+
+        except Exception as e:
+            self.logger.error(f"Error in text analysis: {e}")
+            return text, f"Error analyzing text: {str(e)}", None
+
+    def analyze_multimodal(self, image, text):
+        """
+        Perform multimodal analysis of image and text.
+
+        Args:
+            image: Image file uploaded through Gradio
+            text: Report text input through Gradio
+
+        Returns:
+            tuple: (results_html, multimodal_plot_html)
+        """
+        try:
+            # Ensure models are loaded
+            if not self.load_models() or self.fusion_model is None:
+                return "Error: Models not loaded properly.", None
+
+            # Check for empty inputs
+            if image is None:
+                return "Error: Please upload an X-ray image for analysis.", None
+
+            if not text or len(text.strip()) < 10:
+                return (
+                    "Error: Please enter a valid medical report text (at least 10 characters).",
+                    None,
+                )
+
+            # Save uploaded image to a temporary file
+            temp_dir = tempfile.mkdtemp()
+            temp_path = os.path.join(temp_dir, "upload.png")
+
+            if isinstance(image, str):
+                # Copy the file if it's a path
+                from shutil import copyfile
+
+                copyfile(image, temp_path)
+            else:
+                # Save if it's a Gradio UploadButton image
+                image.save(temp_path)
+
+            # Normalize text
+            normalized_text = normalize_report_text(text)
+
+            # Run multimodal analysis
+            self.logger.info("Performing multimodal analysis")
+            results = self.fusion_model.analyze(temp_path, normalized_text)
+
+            # Create visualization
+            fig = plot_multimodal_results(results, image, text)
+
+            # Convert to HTML for display
+            plot_html = self.fig_to_html(fig)
+
+            # Generate explanation
+            explanation = self.fusion_model.get_explanation(results)
+
+            # Format results as HTML
+            html_result = f"""
+            <h2>Multimodal Medical Analysis Results</h2>
+
+            <h3>Overview</h3>
+            <p><strong>Primary Finding:</strong> {results.get("primary_finding", "Unknown")}</p>
+            <p><strong>Severity Level:</strong> {results.get("severity", {}).get("level", "Unknown")}</p>
+            <p><strong>Severity Score:</strong> {results.get("severity", {}).get("score", 0)}/4</p>
+            <p><strong>Agreement Score:</strong> {results.get("agreement_score", 0):.0%}</p>
+
+            <h3>Detailed Findings</h3>
+            <ul>
+            """
+
+            # Add findings
+            findings = results.get("findings", [])
+            if findings:
+                for finding in findings:
+                    html_result += f"<li>{finding}</li>"
+            else:
+                html_result += "<li>No specific findings detailed.</li>"
+
+            html_result += "</ul>"
+
+            # Add follow-up recommendations
+            html_result += "<h3>Recommended Follow-up</h3><ul>"
+            followups = results.get("followup_recommendations", [])
+
+            if followups:
+                for rec in followups:
+                    html_result += f"<li>{rec}</li>"
+            else:
+                html_result += (
+                    "<li>No specific follow-up recommendations provided.</li>"
+                )
+
+            html_result += "</ul>"
+
+            # Add confidence note
+            confidence = results.get("severity", {}).get("confidence", 0)
+            html_result += f"""
+            <p><em>Note: This analysis has a confidence level of {confidence:.0%}.
+            Please consult with healthcare professionals for official diagnosis.</em></p>
+            """
+
+            return html_result, plot_html
+
+        except Exception as e:
+            self.logger.error(f"Error in multimodal analysis: {e}")
+            return f"Error in multimodal analysis: {str(e)}", None
+
+    def enhance_image(self, image):
+        """
+        Enhance X-ray image contrast.
+
+        Args:
+            image: Image file uploaded through Gradio
+
+        Returns:
+            PIL.Image: Enhanced image
+        """
+        try:
+            if image is None:
+                return None
+
+            # Save uploaded image to a temporary file
+            temp_dir = tempfile.mkdtemp()
+            temp_path = os.path.join(temp_dir, "upload.png")
+
+            if isinstance(image, str):
+                # Copy the file if it's a path
+                from shutil import copyfile
+
+                copyfile(image, temp_path)
+            else:
+                # Save if it's a Gradio UploadButton image
+                image.save(temp_path)
+
+            # Enhance image
+            self.logger.info(f"Enhancing image: {temp_path}")
+            output_path = os.path.join(temp_dir, "enhanced.png")
+            enhance_xray_image(temp_path, output_path)
+
+            # Load enhanced image
+            enhanced = Image.open(output_path)
+            return enhanced
+
+        except Exception as e:
+            self.logger.error(f"Error enhancing image: {e}")
+            return image  # Return original image on error
+
+    def fig_to_html(self, fig):
+        """Convert matplotlib figure to HTML for display in Gradio."""
+        try:
+            import base64
+            import io
+
+            buf = io.BytesIO()
+            fig.savefig(buf, format="png", bbox_inches="tight")
+            buf.seek(0)
+            img_str = base64.b64encode(buf.read()).decode("utf-8")
+            plt.close(fig)
+
+            return f'<img src="data:image/png;base64,{img_str}" alt="Analysis Plot">'
+
+        except Exception as e:
+            self.logger.error(f"Error converting figure to HTML: {e}")
+            return "<p>Error displaying visualization.</p>"
+
+
+def create_interface():
+    """Create and launch the Gradio interface."""
+
+    app = MediSyncApp()
+
+    # Example medical report for demo
+    example_report = """
+    CHEST X-RAY EXAMINATION
+
+    CLINICAL HISTORY: 55-year-old male with cough and fever.
+
+    FINDINGS: The heart size is at the upper limits of normal. The lungs are clear without focal consolidation,
+    effusion, or pneumothorax. There is mild prominence of the pulmonary vasculature. No pleural effusion is seen.
+    There is a small nodular opacity noted in the right lower lobe measuring approximately 8mm, which is suspicious
+    and warrants further investigation. The mediastinum is unremarkable. The visualized bony structures show no acute abnormalities.
+
+    IMPRESSION:
+    1. Mild cardiomegaly.
+    2. 8mm nodular opacity in the right lower lobe, recommend follow-up CT for further evaluation.
+    3. No acute pulmonary parenchymal abnormality.
+
+    RECOMMENDATIONS: Follow-up chest CT to further characterize the nodular opacity in the right lower lobe.
+    """
+
+    # Get sample image path if available
+    sample_images_dir = Path(parent_dir) / "data" / "sample"
+    sample_images = list(sample_images_dir.glob("*.png")) + list(
+        sample_images_dir.glob("*.jpg")
+    )
+
+    sample_image_path = None
+    if sample_images:
+        sample_image_path = str(sample_images[0])
+
+    # Define interface
+    with gr.Blocks(
+        title="MediSync: Multi-Modal Medical Analysis System", theme=gr.themes.Soft()
+    ) as interface:
+        gr.Markdown("""
+        # MediSync: Multi-Modal Medical Analysis System
+
+        This AI-powered healthcare solution combines X-ray image analysis with patient report text processing
+        to provide comprehensive medical insights.
+
+        ## How to Use
+        1. Upload a chest X-ray image
+        2. Enter the corresponding medical report text
+        3. Choose the analysis type: image-only, text-only, or multimodal (combined)
+        """)
+
+        with gr.Tab("Multimodal Analysis"):
+            with gr.Row():
+                with gr.Column():
+                    multi_img_input = gr.Image(label="Upload X-ray Image", type="pil")
+                    multi_img_enhance = gr.Button("Enhance Image")
+
+                    multi_text_input = gr.Textbox(
+                        label="Enter Medical Report Text",
+                        placeholder="Enter the radiologist's report text here...",
+                        lines=10,
+                        value=example_report if sample_image_path is None else None,
+                    )
+
+                    multi_analyze_btn = gr.Button(
+                        "Analyze Image & Text", variant="primary"
+                    )
+
+                with gr.Column():
+                    multi_results = gr.HTML(label="Analysis Results")
+                    multi_plot = gr.HTML(label="Visualization")
+
+            # Set up examples if sample image exists
+            if sample_image_path:
+                gr.Examples(
+                    examples=[[sample_image_path, example_report]],
+                    inputs=[multi_img_input, multi_text_input],
+                    label="Example X-ray and Report",
+                )
+
+        with gr.Tab("Image Analysis"):
+            with gr.Row():
+                with gr.Column():
+                    img_input = gr.Image(label="Upload X-ray Image", type="pil")
+                    img_enhance = gr.Button("Enhance Image")
+                    img_analyze_btn = gr.Button("Analyze Image", variant="primary")
+
+                with gr.Column():
+                    img_output = gr.Image(label="Processed Image")
+                    img_results = gr.HTML(label="Analysis Results")
+                    img_plot = gr.HTML(label="Visualization")
+
+            # Set up example if sample image exists
+            if sample_image_path:
+                gr.Examples(
+                    examples=[[sample_image_path]],
+                    inputs=[img_input],
+                    label="Example X-ray Image",
+                )
+
+        with gr.Tab("Text Analysis"):
+            with gr.Row():
+                with gr.Column():
+                    text_input = gr.Textbox(
+                        label="Enter Medical Report Text",
+                        placeholder="Enter the radiologist's report text here...",
+                        lines=10,
+                        value=example_report,
+                    )
+                    text_analyze_btn = gr.Button("Analyze Text", variant="primary")
+
+                with gr.Column():
+                    text_output = gr.Textbox(label="Processed Text")
+                    text_results = gr.HTML(label="Analysis Results")
+                    text_plot = gr.HTML(label="Entity Visualization")
+
+            # Set up example
+            gr.Examples(
+                examples=[[example_report]],
+                inputs=[text_input],
+                label="Example Medical Report",
+            )
+
+        with gr.Tab("About"):
+            gr.Markdown("""
+            ## About MediSync
+
+            MediSync is an AI-powered healthcare solution that uses multi-modal analysis to provide comprehensive insights from medical images and reports.
+
+            ### Key Features
+
+            - **X-ray Image Analysis**: Detects abnormalities in chest X-rays using pre-trained vision models
+            - **Medical Report Processing**: Extracts key information from patient reports using NLP models
+            - **Multi-modal Integration**: Combines insights from both image and text data for more accurate analysis
+
+            ### Models Used
+
+            - **X-ray Analysis**: facebook/deit-base-patch16-224-medical-cxr
+            - **Medical Text Analysis**: medicalai/ClinicalBERT
+
+            ### Important Disclaimer
+
+            This tool is for educational and research purposes only. It is not intended to provide medical advice or replace professional healthcare. Always consult with qualified healthcare providers for medical decisions.
+            """)
+
+        # Set up event handlers
+        multi_img_enhance.click(
+            app.enhance_image, inputs=multi_img_input, outputs=multi_img_input
+        )
+        multi_analyze_btn.click(
+            app.analyze_multimodal,
+            inputs=[multi_img_input, multi_text_input],
+            outputs=[multi_results, multi_plot],
+        )
+
+        img_enhance.click(app.enhance_image, inputs=img_input, outputs=img_output)
+        img_analyze_btn.click(
+            app.analyze_image,
+            inputs=img_input,
+            outputs=[img_output, img_results, img_plot],
+        )
+
+        text_analyze_btn.click(
+            app.analyze_text,
+            inputs=text_input,
+            outputs=[text_output, text_results, text_plot],
+        )
+
+    # Run the interface
+    interface.launch()
+
+
+if __name__ == "__main__":
+    create_interface()
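
Note: the interface is launched with Gradio defaults (local-only). A hedged variant for deployments that need an externally reachable server; `server_name` and `share` are standard `gradio` launch options, and the values here are illustrative:

    # Instead of interface.launch(), one might use:
    interface.launch(server_name="0.0.0.0", share=False)
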
mediSync/data/sample/sample_info.txt ADDED
@@ -0,0 +1,22 @@
+# Sample X-ray Images
+
+## normal_chest_xray.jpg
+Description: Normal chest X-ray
+Source: https://prod-images-static.radiopaedia.org/images/53448173/322830a37f0fa0852773ca2db3e8d8_big_gallery.jpeg
+
+## pneumonia_xray.jpg
+Description: X-ray with pneumonia
+Source: https://prod-images-static.radiopaedia.org/images/52465460/e4d8791bd7502ab72af8d9e5c322db_big_gallery.jpg
+
+## cardiomegaly_xray.jpg
+Description: X-ray with cardiomegaly
+Source: https://prod-images-static.radiopaedia.org/images/556520/cf17c05750adb04b2a6e23afb47c7d_big_gallery.jpg
+
+## nodule_xray.jpg
+Description: X-ray with lung nodule
+Source: https://prod-images-static.radiopaedia.org/images/19972291/41eed1a2cdad06d26c3f415a6ed65a_big_gallery.jpeg
+
+
+These images are used for testing and demonstration purposes only.
+Please note that these images are from public medical education sources.
+Do not use for clinical decision making.
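
Note: a possible helper, not included in this commit, for fetching the images listed above into mediSync/data/sample/ (file names and URLs are taken from this file; continued availability of the remote images is an assumption):

    import urllib.request
    from pathlib import Path

    SAMPLES = {
        "normal_chest_xray.jpg": "https://prod-images-static.radiopaedia.org/images/53448173/322830a37f0fa0852773ca2db3e8d8_big_gallery.jpeg",
        "pneumonia_xray.jpg": "https://prod-images-static.radiopaedia.org/images/52465460/e4d8791bd7502ab72af8d9e5c322db_big_gallery.jpg",
    }

    dest = Path("mediSync/data/sample")
    dest.mkdir(parents=True, exist_ok=True)
    for name, url in SAMPLES.items():
        urllib.request.urlretrieve(url, dest / name)  # download each sample image
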
mediSync/models/__init__.py ADDED
@@ -0,0 +1,16 @@
+"""
+MediSync: Models Module
+=======================
+
+This module contains the core machine learning models for the MediSync system:
+
+1. XRayImageAnalyzer: Analyzes X-ray images using pre-trained vision models
+2. MedicalReportAnalyzer: Extracts information from medical reports using NLP models
+3. MultimodalFusion: Combines insights from both image and text analysis
+"""
+
+from .image_analyzer import XRayImageAnalyzer
+from .multimodal_fusion import MultimodalFusion
+from .text_analyzer import MedicalReportAnalyzer
+
+__all__ = ["XRayImageAnalyzer", "MedicalReportAnalyzer", "MultimodalFusion"]
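
Note: the three exported classes can also be used individually; a sketch under the assumption that the model weights are downloadable ("xray.png" and the report text are placeholders):

    from mediSync.models import (
        MedicalReportAnalyzer,
        MultimodalFusion,
        XRayImageAnalyzer,
    )

    image_results = XRayImageAnalyzer().analyze("xray.png")
    text_results = MedicalReportAnalyzer().analyze("FINDINGS: mild cardiomegaly.")
    fused = MultimodalFusion().fuse_analyses(image_results, text_results)
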
mediSync/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (826 Bytes).
 
mediSync/models/__pycache__/image_analyzer.cpython-311.pyc ADDED
Binary file (9.32 kB).
 
mediSync/models/__pycache__/multimodal_fusion.cpython-311.pyc ADDED
Binary file (24.9 kB).
 
mediSync/models/__pycache__/text_analyzer.cpython-311.pyc ADDED
Binary file (19 kB).
 
mediSync/models/image_analyzer.py ADDED
@@ -0,0 +1,194 @@
+import logging
+import os
+
+import torch
+from PIL import Image
+from transformers import AutoFeatureExtractor, AutoModelForImageClassification
+
+
+class XRayImageAnalyzer:
+    """
+    A class for analyzing medical X-ray images using pre-trained models from Hugging Face.
+
+    This analyzer uses the DeiT (Data-efficient image Transformers) model fine-tuned
+    on chest X-ray images to detect abnormalities.
+    """
+
+    def __init__(
+        self, model_name="facebook/deit-base-patch16-224-medical-cxr", device=None
+    ):
+        """
+        Initialize the X-ray image analyzer with a specific pre-trained model.
+
+        Args:
+            model_name (str): The Hugging Face model name to use
+            device (str, optional): Device to run the model on ('cuda' or 'cpu')
+        """
+        self.logger = logging.getLogger(__name__)
+
+        # Determine device (CPU or GPU)
+        if device is None:
+            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        else:
+            self.device = device
+
+        self.logger.info(f"Using device: {self.device}")
+
+        # Load model and feature extractor
+        try:
+            self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
+            self.model = AutoModelForImageClassification.from_pretrained(model_name)
+            self.model.to(self.device)
+            self.model.eval()  # Set to evaluation mode
+            self.logger.info(f"Successfully loaded model: {model_name}")
+
+            # Map labels to more informative descriptions
+            self.labels = self.model.config.id2label
+
+        except Exception as e:
+            self.logger.error(f"Failed to load model: {e}")
+            raise
+
+    def preprocess_image(self, image_path):
+        """
+        Preprocess an X-ray image for model input.
+
+        Args:
+            image_path (str or PIL.Image): Path to image or PIL Image object
+
+        Returns:
+            tuple: (processed inputs ready for the model, the loaded PIL image)
+        """
+        try:
+            # Load image if path is provided
+            if isinstance(image_path, str):
+                if not os.path.exists(image_path):
+                    raise FileNotFoundError(f"Image file not found: {image_path}")
+                image = Image.open(image_path).convert("RGB")
+            else:
+                # Assume it's already a PIL Image
+                image = image_path.convert("RGB")
+
+            # Apply feature extraction
+            inputs = self.feature_extractor(images=image, return_tensors="pt")
+            inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+            return inputs, image
+
+        except Exception as e:
+            self.logger.error(f"Error in preprocessing image: {e}")
+            raise
+
+    def analyze(self, image_path, threshold=0.5):
+        """
+        Analyze an X-ray image and detect abnormalities.
+
+        Args:
+            image_path (str or PIL.Image): Path to the X-ray image or PIL Image object
+            threshold (float): Classification threshold for positive findings
+
+        Returns:
+            dict: Analysis results including:
+                - predictions: List of (label, probability) tuples
+                - primary_finding: The most likely abnormality
+                - has_abnormality: Boolean indicating if abnormalities were detected
+                - confidence: Confidence score for the primary finding
+        """
+        try:
+            # Preprocess the image
+            inputs, original_image = self.preprocess_image(image_path)
+
+            # Run inference
+            with torch.no_grad():
+                outputs = self.model(**inputs)
+
+            # Process predictions
+            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
+            probabilities = probabilities.cpu().numpy()
+
+            # Get predictions sorted by probability
+            predictions = []
+            for i, p in enumerate(probabilities):
+                label = self.labels[i]
+                predictions.append((label, float(p)))
+
+            # Sort by probability (descending)
+            predictions.sort(key=lambda x: x[1], reverse=True)
+
+            # Determine if there's an abnormality and the primary finding
+            normal_idx = [
+                i
+                for i, (label, _) in enumerate(predictions)
+                if label.lower() == "normal" or label.lower() == "no finding"
+            ]
+
+            if normal_idx and predictions[normal_idx[0]][1] > threshold:
+                has_abnormality = False
+                primary_finding = "No abnormalities detected"
+                confidence = predictions[normal_idx[0]][1]
+            else:
+                has_abnormality = True
+                primary_finding = predictions[0][0]
+                confidence = predictions[0][1]
+
+            return {
+                "predictions": predictions,
+                "primary_finding": primary_finding,
+                "has_abnormality": has_abnormality,
+                "confidence": confidence,
+            }
+
+        except Exception as e:
+            self.logger.error(f"Error analyzing image: {e}")
+            raise
+
+    def get_explanation(self, results):
+        """
+        Generate a human-readable explanation of the analysis results.
+
+        Args:
+            results (dict): The results returned by the analyze method
+
+        Returns:
+            str: A text explanation of the findings
+        """
+        if not results["has_abnormality"]:
+            explanation = (
+                f"The X-ray appears normal with {results['confidence']:.1%} confidence."
+            )
+        else:
+            explanation = (
+                f"The primary finding is {results['primary_finding']} "
+                f"with {results['confidence']:.1%} confidence.\n\n"
+                f"Other potential findings include:\n"
+            )
+
+            # Add top 3 other findings (skipping the first one which is primary)
+            for label, prob in results["predictions"][1:4]:
+                if prob > 0.05:  # Only include if probability > 5%
+                    explanation += f"- {label}: {prob:.1%}\n"
+
+        return explanation
+
+
+# Example usage
+if __name__ == "__main__":
+    # Set up logging
+    logging.basicConfig(level=logging.INFO)
+
+    # Test on a sample image if available
+    analyzer = XRayImageAnalyzer()
+
+    # Check if sample data directory exists
+    sample_dir = "../data/sample"
+    if os.path.exists(sample_dir) and os.listdir(sample_dir):
+        sample_image = os.path.join(sample_dir, os.listdir(sample_dir)[0])
+        print(f"Analyzing sample image: {sample_image}")
+
+        results = analyzer.analyze(sample_image)
+        explanation = analyzer.get_explanation(results)
+
+        print("\nAnalysis Results:")
+        print(explanation)
+    else:
+        print("No sample images found in ../data/sample directory")
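
Note: a short sketch of consuming the dictionary returned by analyze(), matching the keys documented in its docstring (the image path is hypothetical):

    analyzer = XRayImageAnalyzer()
    results = analyzer.analyze("chest_xray.png", threshold=0.5)

    # Top three labels with probabilities
    for label, prob in results["predictions"][:3]:
        print(f"{label}: {prob:.1%}")

    if results["has_abnormality"]:
        print(analyzer.get_explanation(results))
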
mediSync/models/multimodal_fusion.py ADDED
@@ -0,0 +1,631 @@
+import logging
+
+from .image_analyzer import XRayImageAnalyzer
+from .text_analyzer import MedicalReportAnalyzer
+
+
+class MultimodalFusion:
+    """
+    A class for fusing insights from image analysis and text analysis of medical data.
+
+    This fusion approach combines the strengths of both modalities:
+    - Images provide visual evidence of abnormalities
+    - Text reports provide context, history and radiologist interpretations
+
+    The combined analysis provides a more comprehensive understanding than either modality alone.
+    """
+
+    def __init__(self, image_model=None, text_model=None, device=None):
+        """
+        Initialize the multimodal fusion module with image and text analyzers.
+
+        Args:
+            image_model (str, optional): Model to use for image analysis
+            text_model (str, optional): Model to use for text analysis
+            device (str, optional): Device to run models on ('cuda' or 'cpu')
+        """
+        self.logger = logging.getLogger(__name__)
+
+        # Determine device
+        if device is None:
+            import torch
+
+            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        else:
+            self.device = device
+
+        self.logger.info(f"Using device: {self.device}")
+
+        # Initialize image analyzer
+        try:
+            self.image_analyzer = XRayImageAnalyzer(
+                model_name=image_model
+                if image_model
+                else "facebook/deit-base-patch16-224-medical-cxr",
+                device=self.device,
+            )
+            self.logger.info("Successfully initialized image analyzer")
+        except Exception as e:
+            self.logger.error(f"Failed to initialize image analyzer: {e}")
+            self.image_analyzer = None
+
+        # Initialize text analyzer
+        try:
+            self.text_analyzer = MedicalReportAnalyzer(
+                classifier_model=text_model if text_model else "medicalai/ClinicalBERT",
+                device=self.device,
+            )
+            self.logger.info("Successfully initialized text analyzer")
+        except Exception as e:
+            self.logger.error(f"Failed to initialize text analyzer: {e}")
+            self.text_analyzer = None
+
+    def analyze_image(self, image_path):
+        """
+        Analyze a medical image.
+
+        Args:
+            image_path (str): Path to the medical image
+
+        Returns:
+            dict: Image analysis results
+        """
+        if not self.image_analyzer:
+            self.logger.warning("Image analyzer not available")
+            return {"error": "Image analyzer not available"}
+
+        try:
+            return self.image_analyzer.analyze(image_path)
+        except Exception as e:
+            self.logger.error(f"Error analyzing image: {e}")
+            return {"error": str(e)}
+
+    def analyze_text(self, text):
+        """
+        Analyze medical report text.
+
+        Args:
+            text (str): Medical report text
+
+        Returns:
+            dict: Text analysis results
+        """
+        if not self.text_analyzer:
+            self.logger.warning("Text analyzer not available")
+            return {"error": "Text analyzer not available"}
+
+        try:
+            return self.text_analyzer.analyze(text)
+        except Exception as e:
+            self.logger.error(f"Error analyzing text: {e}")
+            return {"error": str(e)}
+
+    def _calculate_agreement_score(self, image_results, text_results):
+        """
+        Calculate agreement score between image and text analyses.
+
+        Args:
+            image_results (dict): Results from image analysis
+            text_results (dict): Results from text analysis
+
+        Returns:
+            float: Agreement score (0-1, where 1 is perfect agreement)
+        """
+        try:
+            # Default to neutral agreement
+            agreement = 0.5
+
+            # Check if image detected abnormality
+            image_abnormal = image_results.get("has_abnormality", False)
+
+            # Check text severity
+            text_severity = text_results.get("severity", {}).get("level", "Unknown")
+            text_abnormal = text_severity not in ["Normal", "Unknown"]
+
+            # Basic agreement check
+            if image_abnormal == text_abnormal:
+                agreement += 0.25
+            else:
+                agreement -= 0.25
+
+            # Check if specific findings match
+            image_finding = image_results.get("primary_finding", "").lower()
+
+            # Extract problem entities from text
+            problems = text_results.get("entities", {}).get("problem", [])
+            problem_text = " ".join(problems).lower()
+
+            # Check for common keywords in both
+            common_conditions = [
+                "pneumonia",
+                "effusion",
+                "nodule",
+                "mass",
+                "cardiomegaly",
+                "opacity",
+                "fracture",
+                "tumor",
+                "edema",
+            ]
+
+            matching_conditions = 0
+            total_mentioned = 0
+
+            for condition in common_conditions:
+                in_image = condition in image_finding
+                in_text = condition in problem_text
+
+                if in_image or in_text:
+                    total_mentioned += 1
+
+                    if in_image and in_text:
+                        matching_conditions += 1
+                        agreement += 0.05  # Boost agreement for each matching condition
+
+            # Calculate condition match ratio if any conditions were mentioned
+            if total_mentioned > 0:
+                match_ratio = matching_conditions / total_mentioned
+                agreement += match_ratio * 0.2
+
+            # Normalize agreement to 0-1 range
+            agreement = max(0, min(1, agreement))
+
+            return agreement
+
+        except Exception as e:
+            self.logger.error(f"Error calculating agreement score: {e}")
+            return 0.5  # Return neutral agreement on error
+
+    def _get_confidence_weighted_finding(self, image_results, text_results, agreement):
+        """
+        Get the most confident finding weighted by modality confidence.
+
+        Args:
+            image_results (dict): Results from image analysis
+            text_results (dict): Results from text analysis
+            agreement (float): Agreement score between modalities
+
+        Returns:
+            str: Most confident finding
+        """
+        try:
+            image_finding = image_results.get("primary_finding", "")
+            image_confidence = image_results.get("confidence", 0.5)
+
+            # For text, use the most severe problem as primary finding
+            problems = text_results.get("entities", {}).get("problem", [])
+
+            text_confidence = text_results.get("severity", {}).get("confidence", 0.5)
+
+            if not problems:
+                # No problems identified in text
+                if image_confidence > 0.7:
+                    return image_finding
+                else:
+                    return "No significant findings"
+
+            # Simple confidence-weighted selection
+            if image_confidence > text_confidence + 0.2:
+                return image_finding
+            elif problems and text_confidence > image_confidence + 0.2:
+                return (
+                    problems[0]
+                    if isinstance(problems, list) and problems
+                    else "Unknown finding"
+                )
+            else:
+                # Similar confidence, check agreement
+                if agreement > 0.7:
+                    # High agreement, try to find the specific condition mentioned in both
+                    for problem in problems:
+                        if problem.lower() in image_finding.lower():
+                            return problem
+
+                    # Default to image finding if high confidence
+                    if image_confidence > 0.6:
+                        return image_finding
+                    elif problems:
+                        return problems[0]
+                    else:
+                        return image_finding
+                else:
+                    # Low agreement, include both perspectives
+                    if image_finding and problems:
+                        return f"{image_finding} (image) / {problems[0]} (report)"
+                    elif image_finding:
+                        return image_finding
+                    elif problems:
+                        return problems[0]
+                    else:
+                        return "Findings unclear - review recommended"
+
+        except Exception as e:
+            self.logger.error(f"Error getting weighted finding: {e}")
+            return "Unable to determine primary finding"
+
+    def _merge_followup_recommendations(self, image_results, text_results):
+        """
+        Merge follow-up recommendations from both modalities.
+
+        Args:
+            image_results (dict): Results from image analysis
+            text_results (dict): Results from text analysis
+
+        Returns:
+            list: Combined follow-up recommendations
+        """
+        try:
+            # Get text-based recommendations
+            text_recommendations = text_results.get("followup_recommendations", [])
+
+            # Create image-based recommendations based on findings
+            image_recommendations = []
+
+            if image_results.get("has_abnormality", False):
+                primary = image_results.get("primary_finding", "")
+                confidence = image_results.get("confidence", 0)
+
+                if (
+                    "nodule" in primary.lower()
+                    or "mass" in primary.lower()
+                    or "tumor" in primary.lower()
+                ):
+                    image_recommendations.append(
+                        f"Follow-up imaging recommended to further evaluate {primary}."
+                    )
+                elif "pneumonia" in primary.lower():
+                    image_recommendations.append(
+                        "Clinical correlation and follow-up imaging recommended."
+                    )
+                elif confidence > 0.8:
+                    image_recommendations.append(
+                        f"Consider follow-up imaging to monitor {primary}."
+                    )
+                elif confidence > 0.5:
+                    image_recommendations.append(
+                        "Consider clinical correlation and potential follow-up."
+                    )
+
+            # Combine recommendations, removing duplicates
+            all_recommendations = text_recommendations + image_recommendations
+
+            # Remove near-duplicates (similar recommendations)
+            unique_recommendations = []
+            for rec in all_recommendations:
+                if not any(
+                    self._is_similar_recommendation(rec, existing)
+                    for existing in unique_recommendations
+                ):
+                    unique_recommendations.append(rec)
+
+            return unique_recommendations
+
+        except Exception as e:
+            self.logger.error(f"Error merging follow-up recommendations: {e}")
+            return ["Follow-up recommended based on findings."]
+
+    def _is_similar_recommendation(self, rec1, rec2):
+        """Check if two recommendations are semantically similar."""
+        # Convert to lowercase for comparison
+        rec1_lower = rec1.lower()
+        rec2_lower = rec2.lower()
+
+        # Check for significant overlap
+        words1 = set(rec1_lower.split())
+        words2 = set(rec2_lower.split())
+
+        # Calculate Jaccard similarity
+        intersection = words1.intersection(words2)
+        union = words1.union(words2)
+
+        similarity = len(intersection) / len(union) if union else 0
+
+        # Consider similar if more than 60% overlap
+        return similarity > 0.6
+
+    def _get_final_severity(self, image_results, text_results, agreement):
+        """
+        Determine final severity based on both modalities.
+
+        Args:
+            image_results (dict): Results from image analysis
+            text_results (dict): Results from text analysis
+            agreement (float): Agreement score between modalities
+
+        Returns:
+            dict: Final severity assessment
+        """
+        try:
+            # Get text-based severity
+            text_severity = text_results.get("severity", {})
+            text_level = text_severity.get("level", "Unknown")
+            text_score = text_severity.get("score", 0)
+            text_confidence = text_severity.get("confidence", 0.5)
+
+            # Convert image findings to severity
+            image_abnormal = image_results.get("has_abnormality", False)
+            image_confidence = image_results.get("confidence", 0.5)
+
+            # Default severity mapping from image
+            image_severity = "Normal" if not image_abnormal else "Moderate"
+            image_score = 0 if not image_abnormal else 2.0
+
+            # Adjust image severity based on specific findings
+            primary_finding = image_results.get("primary_finding", "").lower()
+
+            # Map certain conditions to severity levels
+            severity_mapping = {
+                "pneumonia": ("Moderate", 2.5),
+                "pneumothorax": ("Severe", 3.0),
+                "effusion": ("Moderate", 2.0),
+                "pulmonary edema": ("Moderate", 2.5),
+                "nodule": ("Mild", 1.5),
+                "mass": ("Moderate", 2.5),
+                "tumor": ("Severe", 3.0),
+                "cardiomegaly": ("Mild", 1.5),
+                "fracture": ("Moderate", 2.0),
+                "consolidation": ("Moderate", 2.0),
+            }
+
+            # Check if any key terms are in the primary finding
+            for key, (severity, score) in severity_mapping.items():
+                if key in primary_finding:
+                    image_severity = severity
+                    image_score = score
+                    break
+
+            # Weight based on confidence and agreement
+            if agreement > 0.7:
+                # High agreement - weight equally
+                final_score = (image_score + text_score) / 2
+            else:
+                # Lower agreement - weight by confidence
+                total_confidence = image_confidence + text_confidence
+                if total_confidence > 0:
+                    image_weight = image_confidence / total_confidence
+                    text_weight = text_confidence / total_confidence
+                    final_score = (image_score * image_weight) + (
+                        text_score * text_weight
+                    )
+                else:
+                    final_score = (image_score + text_score) / 2
+
+            # Map score to severity level
+            severity_levels = {
+                0: "Normal",
+                1: "Mild",
+                2: "Moderate",
+                3: "Severe",
+                4: "Critical",
+            }
+
+            # Round to nearest level
+            level_index = round(min(4, max(0, final_score)))
+            final_level = severity_levels[level_index]
+
+            return {
+                "level": final_level,
+                "score": round(final_score, 1),
+                "confidence": round((image_confidence + text_confidence) / 2, 2),
+            }
+
+        except Exception as e:
+            self.logger.error(f"Error determining final severity: {e}")
+            return {"level": "Unknown", "score": 0, "confidence": 0}
+
+    def fuse_analyses(self, image_results, text_results):
+        """
+        Fuse the results from image and text analyses.
+
+        Args:
+            image_results (dict): Results from image analysis
+            text_results (dict): Results from text analysis
+
+        Returns:
+            dict: Fused analysis results
+        """
+        try:
+            # Calculate agreement between modalities
+            agreement = self._calculate_agreement_score(image_results, text_results)
+            self.logger.info(f"Agreement score between modalities: {agreement:.2f}")
+
+            # Get confidence-weighted primary finding
+            primary_finding = self._get_confidence_weighted_finding(
+                image_results, text_results, agreement
+            )
+
+            # Merge follow-up recommendations
+            followup = self._merge_followup_recommendations(image_results, text_results)
+
+            # Get final severity assessment
+            severity = self._get_final_severity(image_results, text_results, agreement)
+
+            # Create comprehensive findings list
+            findings = []
+
+            # Add text-extracted findings
+            text_findings = text_results.get("findings", [])
+            if text_findings:
+                findings.extend(text_findings)
+
+            # Add primary image finding if not already included
+            image_finding = image_results.get("primary_finding", "")
+            if image_finding and not any(
+                image_finding.lower() in f.lower() for f in findings
+            ):
+                findings.append(f"Image finding: {image_finding}")
+
+            # Create fused result
+            fused_result = {
+                "agreement_score": round(agreement, 2),
+                "primary_finding": primary_finding,
+                "severity": severity,
+                "findings": findings,
+                "followup_recommendations": followup,
+                "modality_results": {"image": image_results, "text": text_results},
+            }
+
+            return fused_result
+
+        except Exception as e:
+            self.logger.error(f"Error fusing analyses: {e}")
+            return {
+                "error": str(e),
+                "modality_results": {"image": image_results, "text": text_results},
+            }
+
+    def analyze(self, image_path, report_text):
+        """
+        Perform multimodal analysis of medical image and report.
+
+        Args:
+            image_path (str): Path to the medical image
+            report_text (str): Medical report text
+
+        Returns:
+            dict: Fused analysis results
+        """
+        try:
+            # Analyze image
+            image_results = self.analyze_image(image_path)
+
+            # Analyze text
+            text_results = self.analyze_text(report_text)
+
+            # Fuse the analyses
+            return self.fuse_analyses(image_results, text_results)
+
+        except Exception as e:
+            self.logger.error(f"Error in multimodal analysis: {e}")
+            return {"error": str(e)}
+
+    def get_explanation(self, fused_results):
+        """
+        Generate a human-readable explanation of the fused analysis.
+
+        Args:
+            fused_results (dict): Results from the fused analysis
+
+        Returns:
+            str: A text explanation of the fused analysis
+        """
+        try:
+            explanation = []
+
+            # Add overview section
+            primary_finding = fused_results.get("primary_finding", "Unknown")
+            severity = fused_results.get("severity", {}).get("level", "Unknown")
+
+            explanation.append("# Medical Analysis Summary\n")
+            explanation.append("## Overview\n")
+            explanation.append(f"Primary finding: **{primary_finding}**\n")
+            explanation.append(f"Severity level: **{severity}**\n")
+
+            # Add agreement information
+            agreement = fused_results.get("agreement_score", 0)
+            agreement_text = (
+                "High" if agreement > 0.7 else "Moderate" if agreement > 0.4 else "Low"
+            )
+
+            explanation.append(
+                f"Image and text analysis agreement: **{agreement_text}** ({agreement:.0%})\n"
+            )
+
+            # Add findings section
+            explanation.append("\n## Detailed Findings\n")
+            findings = fused_results.get("findings", [])
+
+            if findings:
+                for finding in findings:
+                    explanation.append(f"- {finding}\n")
+            else:
+                explanation.append("No specific findings detailed.\n")
+
+            # Add follow-up section
+            explanation.append("\n## Recommended Follow-up\n")
+            followups = fused_results.get("followup_recommendations", [])
+
+            if followups:
+                for followup in followups:
+                    explanation.append(f"- {followup}\n")
+            else:
+                explanation.append("No specific follow-up recommendations provided.\n")
+
+            # Add confidence note
+            confidence = fused_results.get("severity", {}).get("confidence", 0)
+            explanation.append(
+                f"\n*Note: This analysis has a confidence level of {confidence:.0%}. "
+                f"Please consult with healthcare professionals for official diagnosis.*"
+            )
+
+            return "\n".join(explanation)
+
+        except Exception as e:
+            self.logger.error(f"Error generating explanation: {e}")
+            return "Error generating analysis explanation."
+
+
+# Example usage
+if __name__ == "__main__":
+    # Set up logging
+    logging.basicConfig(level=logging.INFO)
+
+    # Test on sample data if available
+    import os
+
+    fusion = MultimodalFusion()
+
+    # Sample text report
+    sample_report = """
+    CHEST X-RAY EXAMINATION
+
+    CLINICAL HISTORY: 55-year-old male with cough and fever.
+
+    FINDINGS: The heart size is at the upper limits of normal. The lungs are clear without focal consolidation,
+    effusion, or pneumothorax. There is mild prominence of the pulmonary vasculature. No pleural effusion is seen.
+    There is a small nodular opacity noted in the right lower lobe measuring approximately 8mm, which is suspicious
+    and warrants further investigation. The mediastinum is unremarkable. The visualized bony structures show no acute abnormalities.
+
+    IMPRESSION:
+    1. Mild cardiomegaly.
+    2. 8mm nodular opacity in the right lower lobe, recommend follow-up CT for further evaluation.
+    3. No acute pulmonary parenchymal abnormality.
+
+    RECOMMENDATIONS: Follow-up chest CT to further characterize the nodular opacity in the right lower lobe.
+    """
+
+    # Check if sample data directory exists and contains images
+    sample_dir = "../data/sample"
+    if os.path.exists(sample_dir) and os.listdir(sample_dir):
+        sample_image = os.path.join(sample_dir, os.listdir(sample_dir)[0])
+        print(f"Analyzing sample image: {sample_image}")
+
+        # Perform multimodal analysis
+        fused_results = fusion.analyze(sample_image, sample_report)
+        explanation = fusion.get_explanation(fused_results)
+
+        print("\nFused Analysis Results:")
+        print(explanation)
+    else:
+        print("No sample images found. Only analyzing text report.")
+
+        # Analyze just the text
+        text_results = fusion.analyze_text(sample_report)
+
+        print("\nText Analysis Results:")
+        print(
+            f"Severity: {text_results['severity']['level']} (Score: {text_results['severity']['score']})"
+        )
+
+        print("\nKey Findings:")
+        for finding in text_results["findings"]:
+            print(f"- {finding}")
+
+        print("\nEntities:")
+        for category, items in text_results["entities"].items():
+            if items:
+                print(f"- {category.capitalize()}: {', '.join(items)}")
+
+        print("\nFollow-up Recommendations:")
+        for rec in text_results["followup_recommendations"]:
+            print(f"- {rec}")
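
Note: because fuse_analyses() operates on plain dictionaries, the fusion and agreement logic can be exercised without running either model; a sketch with hand-built inputs shaped like the analyzers' outputs (values are illustrative, and the analyzers fall back to None if model loading fails offline):

    fusion = MultimodalFusion()
    image_results = {
        "has_abnormality": True,
        "primary_finding": "Nodule",
        "confidence": 0.82,
        "predictions": [("Nodule", 0.82)],
    }
    text_results = {
        "severity": {"level": "Mild", "score": 1.5, "confidence": 0.7},
        "entities": {"problem": ["nodule"]},
        "findings": ["8mm nodular opacity in the right lower lobe"],
        "followup_recommendations": ["Follow-up chest CT recommended."],
    }
    fused = fusion.fuse_analyses(image_results, text_results)
    print(fused["agreement_score"], fused["primary_finding"])
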
mediSync/models/text_analyzer.py ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import re
3
+
4
+ import torch
5
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
6
+
7
+
8
+ class MedicalReportAnalyzer:
9
+ """
10
+ A class for analyzing medical text reports using pre-trained NLP models from Hugging Face.
11
+
12
+ This analyzer can:
13
+ 1. Extract medical entities (conditions, treatments, tests)
14
+ 2. Classify report severity
15
+ 3. Extract key findings
16
+ 4. Identify suggested follow-up actions
17
+ """
18
+
19
+ def __init__(
20
+ self,
21
+ ner_model="samrawal/bert-base-uncased_medical-ner",
22
+ classifier_model="medicalai/ClinicalBERT",
23
+ device=None,
24
+ ):
25
+ """
26
+ Initialize the text analyzer with specific pre-trained models.
27
+
28
+ Args:
29
+ ner_model (str): Model for named entity recognition
30
+ classifier_model (str): Model for text classification
31
+ device (str, optional): Device to run models on ('cuda' or 'cpu')
32
+ """
33
+ self.logger = logging.getLogger(__name__)
34
+
35
+ # Determine device
36
+ if device is None:
37
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
38
+ else:
39
+ self.device = device
40
+
41
+ self.logger.info(f"Using device: {self.device}")
42
+
43
+ # Load NER model for entity extraction
44
+ try:
45
+ self.ner_pipeline = pipeline(
46
+ "token-classification",
47
+ model=ner_model,
48
+ aggregation_strategy="simple",
49
+ device=0 if self.device == "cuda" else -1,
50
+ )
51
+ self.logger.info(f"Successfully loaded NER model: {ner_model}")
52
+ except Exception as e:
53
+ self.logger.error(f"Failed to load NER model: {e}")
54
+ self.ner_pipeline = None
55
+
56
+ # Load classifier model for severity assessment
57
+ try:
58
+ self.tokenizer = AutoTokenizer.from_pretrained(classifier_model)
59
+ self.classifier = AutoModelForSequenceClassification.from_pretrained(
60
+ classifier_model
61
+ )
62
+ self.classifier.to(self.device)
63
+ self.classifier.eval()
64
+ self.logger.info(
65
+ f"Successfully loaded classifier model: {classifier_model}"
66
+ )
67
+ except Exception as e:
68
+ self.logger.error(f"Failed to load classifier model: {e}")
69
+ self.classifier = None
70
+
71
+ # Severity levels mapping
72
+ self.severity_levels = {
73
+ 0: "Normal",
74
+ 1: "Mild",
75
+ 2: "Moderate",
76
+ 3: "Severe",
77
+ 4: "Critical",
78
+ }
79
+
80
+ # Common medical findings and their severity levels
81
+ self.finding_severity = {
82
+ "pneumonia": 3,
83
+ "fracture": 3,
84
+ "tumor": 4,
85
+ "nodule": 2,
86
+ "mass": 3,
87
+ "edema": 2,
88
+ "effusion": 2,
89
+ "hemorrhage": 3,
90
+ "opacity": 1,
91
+ "atelectasis": 2,
92
+ "pneumothorax": 3,
93
+ "consolidation": 2,
94
+ "cardiomegaly": 2,
95
+ }
96
+
97
+ def extract_entities(self, text):
98
+ """
99
+ Extract medical entities from the report text.
100
+
101
+ Args:
102
+ text (str): Medical report text
103
+
104
+ Returns:
105
+ dict: Dictionary of entity lists by category
106
+ """
107
+ if not self.ner_pipeline:
108
+ self.logger.warning("NER model not available")
109
+ return {}
110
+
111
+ try:
112
+ # Run NER
113
+ entities = self.ner_pipeline(text)
114
+
115
+ # Group entities by type
116
+ grouped_entities = {
117
+ "problem": [], # Medical conditions
118
+ "test": [], # Tests/procedures
119
+ "treatment": [], # Treatments/medications
120
+ "anatomy": [], # Anatomical locations
121
+ }
122
+
123
+ for entity in entities:
124
+ entity_type = entity.get("entity_group", "").lower()
125
+
126
+ # Map entity types to our categories
127
+ if entity_type in ["problem", "disease", "condition", "diagnosis"]:
128
+ category = "problem"
129
+ elif entity_type in ["test", "procedure", "examination"]:
130
+ category = "test"
131
+ elif entity_type in ["treatment", "medication", "drug"]:
132
+ category = "treatment"
133
+ elif entity_type in ["body_part", "anatomy", "organ"]:
134
+ category = "anatomy"
135
+ else:
136
+ continue # Skip other entity types
137
+
138
+ word = entity.get("word", "")
139
+ score = entity.get("score", 0)
140
+
141
+ # Only include if confidence is reasonable
142
+ if score > 0.7 and word not in grouped_entities[category]:
143
+ grouped_entities[category].append(word)
144
+
145
+ return grouped_entities
146
+
147
+ except Exception as e:
148
+ self.logger.error(f"Error extracting entities: {e}")
149
+ return {}
150
+
151
+ def assess_severity(self, text):
152
+ """
153
+ Assess the severity level of the medical report.
154
+
155
+ Args:
156
+ text (str): Medical report text
157
+
158
+ Returns:
159
+ dict: Severity assessment including level and confidence
160
+ """
161
+ if not self.classifier:
162
+ self.logger.warning("Classifier model not available")
163
+ return {"level": "Unknown", "score": 0.0}
164
+
165
+ try:
166
+ # Keyword/rule-based scoring; the loaded classifier model is not consulted here
167
+ severity_score = 0
168
+ confidence = 0.5 # Start with neutral confidence
169
+
170
+ # Check for severe keywords
171
+ severe_keywords = [
172
+ "severe",
173
+ "critical",
174
+ "urgent",
175
+ "emergency",
176
+ "immediate attention",
177
+ ]
178
+ moderate_keywords = ["moderate", "concerning", "follow-up", "monitor"]
179
+ mild_keywords = ["mild", "minimal", "slight", "minor"]
180
+ normal_keywords = [
181
+ "normal",
182
+ "unremarkable",
183
+ "no abnormalities",
184
+ "within normal limits",
185
+ ]
186
+
187
+ # Count keyword occurrences
188
+ text_lower = text.lower()
189
+ severe_count = sum(text_lower.count(word) for word in severe_keywords)
190
+ moderate_count = sum(text_lower.count(word) for word in moderate_keywords)
191
+ mild_count = sum(text_lower.count(word) for word in mild_keywords)
192
+ normal_count = sum(text_lower.count(word) for word in normal_keywords)
193
+
194
+ # Adjust severity based on keyword counts
195
+ if severe_count > 0:
196
+ severity_score += min(severe_count, 2) * 1.5
197
+ confidence += 0.1
198
+ if moderate_count > 0:
199
+ severity_score += min(moderate_count, 3) * 0.75
200
+ confidence += 0.05
201
+ if mild_count > 0:
202
+ severity_score += min(mild_count, 3) * 0.25
203
+ confidence += 0.05
204
+ if normal_count > 0:
205
+ severity_score -= min(normal_count, 3) * 0.75
206
+ confidence += 0.1
207
+
208
+ # Check for specific medical findings
209
+ for finding, level in self.finding_severity.items():
210
+ if finding in text_lower:
211
+ severity_score += level * 0.5
212
+ confidence += 0.05
213
+
214
+ # Normalize severity score to 0-4 range
215
+ severity_score = max(0, min(4, severity_score))
216
+ severity_level = int(round(severity_score))
217
+
218
+ # Map to severity level
219
+ severity = self.severity_levels.get(severity_level, "Moderate")
220
+
221
+ # Cap confidence at 0.95
222
+ confidence = min(0.95, confidence)
223
+
224
+ return {
225
+ "level": severity,
226
+ "score": round(severity_score, 1),
227
+ "confidence": round(confidence, 2),
228
+ }
229
+
230
+ except Exception as e:
231
+ self.logger.error(f"Error assessing severity: {e}")
232
+ return {"level": "Unknown", "score": 0.0, "confidence": 0.0}
233
+
234
+ def extract_findings(self, text):
235
+ """
236
+ Extract key clinical findings from the report.
237
+
238
+ Args:
239
+ text (str): Medical report text
240
+
241
+ Returns:
242
+ list: List of key findings
243
+ """
244
+ try:
245
+ # Split text into sentences
246
+ sentences = re.split(r"[.!?]\s+", text)
247
+ findings = []
248
+
249
+ # Key phrases that often introduce findings
250
+ finding_markers = [
251
+ "finding",
252
+ "observed",
253
+ "noted",
254
+ "shows",
255
+ "reveals",
256
+ "demonstrates",
257
+ "indicates",
258
+ "evident",
259
+ "apparent",
260
+ "consistent with",
261
+ "suggestive of",
262
+ ]
263
+
264
+ # Negative markers
265
+ negation_markers = ["no", "not", "none", "negative", "without", "denies"]
266
+
267
+ for sentence in sentences:
268
+ # Skip very short sentences
269
+ if len(sentence.split()) < 3:
270
+ continue
271
+
272
+ sentence = sentence.strip()
273
+
274
+ # Check if this sentence likely contains a finding
275
+ contains_finding_marker = any(
276
+ marker in sentence.lower() for marker in finding_markers
277
+ )
278
+
279
+ # Check for negation
280
+ contains_negation = any(
281
+ marker in sentence.lower().split() for marker in negation_markers
282
+ )
283
+
284
+ # Only include positive findings or explicitly negated findings that are important
285
+ if contains_finding_marker or (
286
+ contains_negation
287
+ and any(
288
+ term in sentence.lower()
289
+ for term in self.finding_severity.keys()
290
+ )
291
+ ):
292
+ findings.append(sentence)
293
+
294
+ return findings
295
+
296
+ except Exception as e:
297
+ self.logger.error(f"Error extracting findings: {e}")
298
+ return []
299
+
300
+ def suggest_followup(self, text, entities, severity):
301
+ """
302
+ Suggest follow-up actions based on report analysis.
303
+
304
+ Args:
305
+ text (str): Medical report text
306
+ entities (dict): Extracted entities
307
+ severity (dict): Severity assessment
308
+
309
+ Returns:
310
+ list: Suggested follow-up actions
311
+ """
312
+ try:
313
+ followups = []
314
+
315
+ # Base recommendations on severity
+ severity_level = severity.get("level", "Unknown")
318
+
319
+ # Extract problems from entities
320
+ problems = entities.get("problem", [])
321
+
322
+ # Check if follow-up is already mentioned in the text
323
+ followup_mentioned = any(
324
+ phrase in text.lower()
325
+ for phrase in [
326
+ "follow up",
327
+ "follow-up",
328
+ "followup",
329
+ "return",
330
+ "refer",
331
+ "consult",
332
+ ]
333
+ )
334
+
335
+ # Default recommendations based on severity
336
+ if severity_level == "Critical":
337
+ followups.append("Immediate specialist consultation recommended.")
338
+
339
+ elif severity_level == "Severe":
340
+ followups.append("Prompt follow-up with specialist is recommended.")
341
+
342
+ # Add specific recommendations for common severe conditions
343
+ for problem in problems:
344
+ if "pneumonia" in problem.lower():
345
+ followups.append(
346
+ "Consider antibiotic therapy and close monitoring."
347
+ )
348
+ elif "fracture" in problem.lower():
349
+ followups.append(
350
+ "Orthopedic consultation for treatment planning."
351
+ )
352
+ elif "mass" in problem.lower() or "tumor" in problem.lower():
353
+ followups.append(
354
+ "Further imaging and possible biopsy recommended."
355
+ )
356
+
357
+ elif severity_level == "Moderate":
358
+ followups.append("Follow-up with primary care physician recommended.")
359
+ if not followup_mentioned and problems:
360
+ followups.append(
361
+ "Consider additional imaging or tests for further evaluation."
362
+ )
363
+
364
+ elif severity_level == "Mild":
365
+ if problems:
366
+ followups.append(
367
+ "Routine follow-up with primary care physician as needed."
368
+ )
369
+ else:
370
+ followups.append("No immediate follow-up required.")
371
+
372
+ else: # Normal
373
+ followups.append(
374
+ "No specific follow-up indicated based on this report."
375
+ )
376
+
377
+ # Check for specific findings that always need follow-up
378
+ for critical_term in ["mass", "tumor", "nodule", "opacity"]:
379
+ if (
380
+ critical_term in text.lower()
381
+ and "follow-up" not in " ".join(followups).lower()
382
+ ):
383
+ followups.append(
384
+ f"Follow-up imaging recommended to monitor {critical_term}."
385
+ )
386
+ break
387
+
388
+ return followups
389
+
390
+ except Exception as e:
391
+ self.logger.error(f"Error suggesting follow-up: {e}")
392
+ return ["Unable to generate follow-up recommendations."]
393
+
394
+ def analyze(self, text):
395
+ """
396
+ Perform comprehensive analysis of medical report text.
397
+
398
+ Args:
399
+ text (str): Medical report text
400
+
401
+ Returns:
402
+ dict: Complete analysis results
403
+ """
404
+ try:
405
+ # Extract entities
406
+ entities = self.extract_entities(text)
407
+
408
+ # Assess severity
409
+ severity = self.assess_severity(text)
410
+
411
+ # Extract key findings
412
+ findings = self.extract_findings(text)
413
+
414
+ # Generate follow-up suggestions
415
+ followups = self.suggest_followup(text, entities, severity)
416
+
417
+ # Create detailed report
418
+ report = {
419
+ "entities": entities,
420
+ "severity": severity,
421
+ "findings": findings,
422
+ "followup_recommendations": followups,
423
+ }
424
+
425
+ return report
426
+
427
+ except Exception as e:
428
+ self.logger.error(f"Error analyzing report: {e}")
429
+ return {"error": str(e)}
430
+
431
+
432
+ # Example usage
433
+ if __name__ == "__main__":
434
+ # Set up logging
435
+ logging.basicConfig(level=logging.INFO)
436
+
437
+ # Test on a sample report
438
+ analyzer = MedicalReportAnalyzer()
439
+
440
+ sample_report = """
441
+ CHEST X-RAY EXAMINATION
442
+
443
+ CLINICAL HISTORY: 55-year-old male with cough and fever.
444
+
445
+ FINDINGS: The heart size is at the upper limits of normal. There is no focal consolidation,
+ effusion, or pneumothorax. There is mild prominence of the pulmonary vasculature.
447
+ There is a small nodular opacity noted in the right lower lobe measuring approximately 8mm, which is suspicious
448
+ and warrants further investigation. The mediastinum is unremarkable. The visualized bony structures show no acute abnormalities.
449
+
450
+ IMPRESSION:
451
+ 1. Mild cardiomegaly.
452
+ 2. 8mm nodular opacity in the right lower lobe, recommend follow-up CT for further evaluation.
453
+ 3. No acute pulmonary parenchymal abnormality.
454
+
455
+ RECOMMENDATIONS: Follow-up chest CT to further characterize the nodular opacity in the right lower lobe.
456
+ """
457
+
458
+ results = analyzer.analyze(sample_report)
459
+
460
+ print("\nMedical Report Analysis:")
461
+ print(
462
+ f"\nSeverity: {results['severity']['level']} (Score: {results['severity']['score']})"
463
+ )
464
+
465
+ print("\nKey Findings:")
466
+ for finding in results["findings"]:
467
+ print(f"- {finding}")
468
+
469
+ print("\nEntities:")
470
+ for category, items in results["entities"].items():
471
+ if items:
472
+ print(f"- {category.capitalize()}: {', '.join(items)}")
473
+
474
+ print("\nFollow-up Recommendations:")
475
+ for rec in results["followup_recommendations"]:
476
+ print(f"- {rec}")
mediSync/utils/__init__.py ADDED
@@ -0,0 +1,38 @@
1
+ """
2
+ MediSync: Utils Module
3
+ ======================
4
+
5
+ This module contains utility functions for the MediSync system:
6
+
7
+ 1. preprocessing: Functions for preprocessing images and text
8
+ 2. visualization: Functions for visualizing analysis results
9
+ 3. download_samples: Functions for downloading sample data
10
+ """
11
+
12
+ from .preprocessing import (
13
+ enhance_xray_image,
14
+ extract_measurements,
15
+ extract_sections,
16
+ normalize_report_text,
17
+ preprocess_image,
18
+ )
19
+ from .visualization import (
20
+ create_heatmap_overlay,
21
+ figure_to_base64,
22
+ plot_image_prediction,
23
+ plot_multimodal_results,
24
+ plot_report_entities,
25
+ )
26
+
27
+ __all__ = [
28
+ "preprocess_image",
29
+ "normalize_report_text",
30
+ "enhance_xray_image",
31
+ "extract_sections",
32
+ "extract_measurements",
33
+ "plot_image_prediction",
34
+ "plot_report_entities",
35
+ "plot_multimodal_results",
36
+ "create_heatmap_overlay",
37
+ "figure_to_base64",
38
+ ]
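For reference, a minimal sketch of the re-exported API in use; the image path matches the samples fetched by `download_samples.py` below, and the report string is a placeholder.

```python
# Sketch: exercising the utils package through its top-level exports.
from mediSync.utils import (
    enhance_xray_image,
    extract_sections,
    normalize_report_text,
    preprocess_image,
)

# Image helpers: resize for model input, CLAHE-enhance for display
image = preprocess_image("data/sample/normal_chest_xray.jpg")  # 224x224 RGB PIL image
enhanced_path = enhance_xray_image(
    "data/sample/normal_chest_xray.jpg", output_path="enhanced.png"
)

# Text helpers: standardize headers/abbreviations, then split into sections
report = normalize_report_text("findings: lungs clear. impression: normal study.")
sections = extract_sections(report)  # e.g. {"FINDINGS": "lungs clear.", ...}
```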
mediSync/utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.07 kB). View file
 
mediSync/utils/__pycache__/download_samples.cpython-311.pyc ADDED
Binary file (5.76 kB). View file
 
mediSync/utils/__pycache__/preprocessing.cpython-311.pyc ADDED
Binary file (9.26 kB). View file
 
mediSync/utils/__pycache__/visualization.cpython-311.pyc ADDED
Binary file (18 kB). View file
 
mediSync/utils/download_samples.py ADDED
@@ -0,0 +1,135 @@
1
+ import logging
2
+ import urllib.request
3
+ from pathlib import Path
4
+
5
+ # Set up logging
6
+ logging.basicConfig(
7
+ level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
8
+ )
9
+ logger = logging.getLogger(__name__)
10
+
11
+ # Sample X-ray image URLs (from public sources)
12
+ SAMPLE_IMAGES = [
13
+ # Normal chest X-ray
14
+ {
15
+ "url": "https://prod-images-static.radiopaedia.org/images/53448173/322830a37f0fa0852773ca2db3e8d8_big_gallery.jpeg",
16
+ "filename": "normal_chest_xray.jpg",
17
+ "description": "Normal chest X-ray",
18
+ },
19
+ # X-ray with pneumonia
20
+ {
21
+ "url": "https://prod-images-static.radiopaedia.org/images/52465460/e4d8791bd7502ab72af8d9e5c322db_big_gallery.jpg",
22
+ "filename": "pneumonia_xray.jpg",
23
+ "description": "X-ray with pneumonia",
24
+ },
25
+ # X-ray with cardiomegaly
26
+ {
27
+ "url": "https://prod-images-static.radiopaedia.org/images/556520/cf17c05750adb04b2a6e23afb47c7d_big_gallery.jpg",
28
+ "filename": "cardiomegaly_xray.jpg",
29
+ "description": "X-ray with cardiomegaly",
30
+ },
31
+ # X-ray with lung nodule
32
+ {
33
+ "url": "https://prod-images-static.radiopaedia.org/images/19972291/41eed1a2cdad06d26c3f415a6ed65a_big_gallery.jpeg",
34
+ "filename": "nodule_xray.jpg",
35
+ "description": "X-ray with lung nodule",
36
+ },
37
+ ]
38
+
39
+
40
+ def download_sample_images(output_dir="data/sample"):
41
+ """
42
+ Download sample X-ray images for testing.
43
+
44
+ Args:
45
+ output_dir (str): Directory to save images
46
+
47
+ Returns:
48
+ list: Paths to downloaded images
49
+ """
50
+ # Get the directory of the script
51
+ script_dir = Path(__file__).resolve().parent.parent
52
+
53
+ # Create output directory if it doesn't exist
54
+ output_path = script_dir / output_dir
55
+ output_path.mkdir(parents=True, exist_ok=True)
56
+
57
+ downloaded_paths = []
58
+
59
+ for image in SAMPLE_IMAGES:
60
+ try:
61
+ filename = image["filename"]
62
+ url = image["url"]
63
+ output_file = output_path / filename
64
+
65
+ # Skip if file already exists
66
+ if output_file.exists():
67
+ logger.info(f"File already exists: {output_file}")
68
+ downloaded_paths.append(str(output_file))
69
+ continue
70
+
71
+ # Download the image
72
+ logger.info(f"Downloading {url} to {output_file}")
73
+
74
+ # Set a user agent to avoid blocking
75
+ opener = urllib.request.build_opener()
76
+ opener.addheaders = [("User-Agent", "Mozilla/5.0")]
77
+ urllib.request.install_opener(opener)
78
+
79
+ # Download the file
80
+ urllib.request.urlretrieve(url, output_file)
81
+
82
+ logger.info(f"Successfully downloaded {filename}")
83
+ downloaded_paths.append(str(output_file))
84
+
85
+ except Exception as e:
86
+ logger.error(f"Error downloading {image['url']}: {e}")
87
+
88
+ logger.info(
89
+ f"Downloaded {len(downloaded_paths)} out of {len(SAMPLE_IMAGES)} images"
90
+ )
91
+ return downloaded_paths
92
+
93
+
94
+ def create_sample_info_file(output_dir="data/sample"):
95
+ """
96
+ Create a text file with information about the sample images.
97
+
98
+ Args:
99
+ output_dir (str): Directory with sample images
100
+ """
101
+ # Get the directory of the script
102
+ script_dir = Path(__file__).resolve().parent.parent
103
+
104
+ # Output path
105
+ output_path = script_dir / output_dir
106
+ info_file = output_path / "sample_info.txt"
107
+
108
+ with open(info_file, "w") as f:
109
+ f.write("# Sample X-ray Images\n\n")
110
+
111
+ for image in SAMPLE_IMAGES:
112
+ f.write(f"## {image['filename']}\n")
113
+ f.write(f"Description: {image['description']}\n")
114
+ f.write(f"Source: {image['url']}\n\n")
115
+
116
+ f.write(
117
+ "\nThese images are used for testing and demonstration purposes only.\n"
118
+ )
119
+ f.write(
120
+ "Please note that these images are from public medical education sources.\n"
121
+ )
122
+ f.write("Do not use for clinical decision making.\n")
123
+
124
+ logger.info(f"Created sample info file: {info_file}")
125
+
126
+
127
+ if __name__ == "__main__":
128
+ # Download sample images
129
+ downloaded_paths = download_sample_images()
130
+
131
+ # Create info file
132
+ create_sample_info_file()
133
+
134
+ print(f"Downloaded {len(downloaded_paths)} sample images.")
135
+ print("Run the application with: python app.py")
mediSync/utils/preprocessing.py ADDED
@@ -0,0 +1,262 @@
1
+ import logging
2
+ import os
3
+ import re
4
+
5
+ import cv2
6
+ from PIL import Image
7
+
8
+ # Set up logging
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ def preprocess_image(image_path, target_size=(224, 224)):
13
+ """
14
+ Preprocess X-ray image for model input.
15
+
16
+ Args:
17
+ image_path (str): Path to the X-ray image
18
+ target_size (tuple): Target size for resizing
19
+
20
+ Returns:
21
+ PIL.Image: Preprocessed image
22
+ """
23
+ try:
24
+ # Check if file exists
25
+ if not os.path.exists(image_path):
26
+ raise FileNotFoundError(f"Image file not found: {image_path}")
27
+
28
+ # Load image
29
+ image = Image.open(image_path)
30
+
31
+ # Convert grayscale to RGB if needed
32
+ if image.mode != "RGB":
33
+ image = image.convert("RGB")
34
+
35
+ # Resize image
36
+ image = image.resize(target_size, Image.LANCZOS)
37
+
38
+ return image
39
+
40
+ except Exception as e:
41
+ logger.error(f"Error preprocessing image: {e}")
42
+ raise
43
+
44
+
45
+ def enhance_xray_image(image_path, output_path=None, clahe_clip=2.0, clahe_grid=(8, 8)):
46
+ """
47
+ Enhance X-ray image contrast using CLAHE (Contrast Limited Adaptive Histogram Equalization).
48
+
49
+ Args:
50
+ image_path (str): Path to the X-ray image
51
+ output_path (str, optional): Path to save enhanced image
52
+ clahe_clip (float): Clip limit for CLAHE
53
+ clahe_grid (tuple): Grid size for CLAHE
54
+
55
+ Returns:
56
+ str or np.ndarray: Path to enhanced image or image array
57
+ """
58
+ try:
59
+ # Read image
60
+ img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
61
+
62
+ if img is None:
63
+ raise ValueError(f"Failed to read image: {image_path}")
64
+
65
+ # Create CLAHE object
66
+ clahe = cv2.createCLAHE(clipLimit=clahe_clip, tileGridSize=clahe_grid)
67
+
68
+ # Apply CLAHE
69
+ enhanced = clahe.apply(img)
70
+
71
+ # Save enhanced image if output path is provided
72
+ if output_path:
73
+ cv2.imwrite(output_path, enhanced)
74
+ return output_path
75
+ else:
76
+ return enhanced
77
+
78
+ except Exception as e:
79
+ logger.error(f"Error enhancing X-ray image: {e}")
80
+ raise
81
+
82
+
83
+ def normalize_report_text(text):
84
+ """
85
+ Normalize medical report text for consistent processing.
86
+
87
+ Args:
88
+ text (str): Medical report text
89
+
90
+ Returns:
91
+ str: Normalized text
92
+ """
93
+ try:
94
+ # Remove multiple whitespaces
95
+ text = re.sub(r"\s+", " ", text)
96
+
97
+ # Standardize section headers
98
+ section_patterns = {
99
+ r"(?i)clinical\s*(?:history|indication)": "CLINICAL HISTORY:",
100
+ r"(?i)technique": "TECHNIQUE:",
101
+ r"(?i)comparison": "COMPARISON:",
102
+ r"(?i)findings": "FINDINGS:",
103
+ r"(?i)impression": "IMPRESSION:",
104
+ r"(?i)recommendation": "RECOMMENDATION:",
105
+ r"(?i)comment": "COMMENT:",
106
+ }
107
+
108
+ for pattern, replacement in section_patterns.items():
109
+ text = re.sub(pattern + r"\s*:", replacement, text)
110
+
111
+ # Standardize common abbreviations
112
+ abbrev_patterns = {
113
+ r"(?i)\bw\/\b": "with",
114
+ r"(?i)\bw\/o\b": "without",
115
+ r"(?i)\bs\/p\b": "status post",
116
+ r"(?i)\bc\/w\b": "consistent with",
117
+ r"(?i)\br\/o\b": "rule out",
118
+ r"(?i)\bhx\b": "history",
119
+ r"(?i)\bdx\b": "diagnosis",
120
+ r"(?i)\btx\b": "treatment",
121
+ }
122
+
123
+ for pattern, replacement in abbrev_patterns.items():
124
+ text = re.sub(pattern, replacement, text)
125
+
126
+ return text.strip()
127
+
128
+ except Exception as e:
129
+ logger.error(f"Error normalizing report text: {e}")
130
+ return text # Return original text if normalization fails
131
+
132
+
133
+ def extract_sections(text):
134
+ """
135
+ Extract sections from a medical report.
136
+
137
+ Args:
138
+ text (str): Medical report text
139
+
140
+ Returns:
141
+ dict: Dictionary of extracted sections
142
+ """
143
+ try:
144
+ # Normalize text first
145
+ normalized_text = normalize_report_text(text)
146
+
147
+ # Define section patterns
148
+ section_headers = [
149
+ "CLINICAL HISTORY:",
150
+ "TECHNIQUE:",
151
+ "COMPARISON:",
152
+ "FINDINGS:",
153
+ "IMPRESSION:",
154
+ "RECOMMENDATION:",
155
+ ]
156
+
157
+ # normalize_report_text collapses all whitespace into single spaces,
+ # so the normalized text contains no newlines; split on the
+ # standardized headers themselves rather than on lines
+ header_pattern = "(" + "|".join(re.escape(h) for h in section_headers) + ")"
+ parts = re.split(header_pattern, normalized_text)
+
+ sections = {}
+ sections["PREAMBLE"] = [parts[0]]  # Text before the first section header
+
+ # re.split with a capturing group alternates (header, content) pairs
+ for i in range(1, len(parts), 2):
+     current_section = parts[i].rstrip(":")
+     sections[current_section] = [parts[i + 1]]
178
+
179
+ # Join each section's lines
180
+ for section, lines in sections.items():
181
+ sections[section] = " ".join(lines).strip()
182
+
183
+ # Remove empty sections
184
+ sections = {k: v for k, v in sections.items() if v}
185
+
186
+ return sections
187
+
188
+ except Exception as e:
189
+ logger.error(f"Error extracting sections: {e}")
190
+ return {"FULL_TEXT": text} # Return full text if extraction fails
191
+
192
+
193
+ def extract_measurements(text):
194
+ """
195
+ Extract measurements from medical text (sizes, volumes, etc.).
196
+
197
+ Args:
198
+ text (str): Medical text
199
+
200
+ Returns:
201
+ list: List of tuples containing (measurement, value, unit)
202
+ """
203
+ try:
204
+ # Pattern for measurements like "5mm nodule" or "nodule measuring 5mm"
205
+ # or "8x10mm mass" or "mass of size 8x10mm"
206
+ size_pattern = r"(\d+(?:\.\d+)?(?:\s*[x×]\s*\d+(?:\.\d+)?)?(?:\s*[x×]\s*\d+(?:\.\d+)?)?)\s*(mm|cm|mm2|cm2|mm3|cm3|ml|cc)"
207
+
208
+ # Find measurements with context
209
+ context_pattern = (
210
+ r"([A-Za-z\s]+(?:mass|nodule|effusion|opacity|lesion|tumor|cyst|structure|area|region)[A-Za-z\s]*)"
211
+ + size_pattern
212
+ )
213
+
214
+ context_measurements = []
215
+ for match in re.finditer(context_pattern, text, re.IGNORECASE):
216
+ context, size, unit = match.groups()
217
+ context_measurements.append((context.strip(), size, unit))
218
+
219
+ # For measurements without clear context, fall back to bare size/unit
+ # pairs, keeping the documented (context, size, unit) shape
+ if not context_measurements:
+     for match in re.finditer(size_pattern, text):
+         size, unit = match.groups()
+         context_measurements.append(("", size, unit))
+
+ return context_measurements
226
+
227
+ except Exception as e:
228
+ logger.error(f"Error extracting measurements: {e}")
229
+ return []
230
+
231
+
232
+ def prepare_sample_batch(image_paths, reports=None, target_size=(224, 224)):
233
+ """
234
+ Prepare a batch of samples for model processing.
235
+
236
+ Args:
237
+ image_paths (list): List of paths to images
238
+ reports (list, optional): List of corresponding reports
239
+ target_size (tuple): Target image size
240
+
241
+ Returns:
242
+ tuple: Batch of preprocessed images and reports
243
+ """
244
+ try:
245
+ processed_images = []
246
+ processed_reports = []
247
+
248
+ for i, image_path in enumerate(image_paths):
249
+ # Process image
250
+ image = preprocess_image(image_path, target_size)
251
+ processed_images.append(image)
252
+
253
+ # Process report if available
254
+ if reports and i < len(reports):
255
+ normalized_report = normalize_report_text(reports[i])
256
+ processed_reports.append(normalized_report)
257
+
258
+ return processed_images, (processed_reports if reports else None)
259
+
260
+ except Exception as e:
261
+ logger.error(f"Error preparing sample batch: {e}")
262
+ raise
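A quick demonstration of the two text helpers on a report fragment; the printed shapes are indicative, and the exact context strings returned by `extract_measurements` depend on its regex.

```python
# Sketch: exercising the report-text helpers on a short fragment.
from mediSync.utils.preprocessing import extract_measurements, extract_sections

text = (
    "CLINICAL HISTORY: cough and fever. "
    "FINDINGS: There is a nodular opacity measuring 8mm in the right lower lobe. "
    "IMPRESSION: Follow-up CT recommended."
)

print(extract_sections(text))
# sections keyed by header, e.g. {'CLINICAL HISTORY': 'cough and fever.', ...}

print(extract_measurements(text))
# (context, size, unit) tuples, roughly: [('... nodular opacity measuring', '8', 'mm')]
```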
mediSync/utils/visualization.py ADDED
@@ -0,0 +1,516 @@
1
+ import base64
2
+ import io
3
+ import logging
4
+
5
+ import cv2
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ from PIL import Image
9
+
10
+ # Set up logging
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ def plot_image_prediction(image, predictions, title=None, figsize=(10, 8)):
15
+ """
16
+ Plot an image with its predictions.
17
+
18
+ Args:
19
+ image (PIL.Image or str): Image or path to image
20
+ predictions (list): List of (label, probability) tuples
21
+ title (str, optional): Plot title
22
+ figsize (tuple): Figure size
23
+
24
+ Returns:
25
+ matplotlib.figure.Figure: The figure object
26
+ """
27
+ try:
28
+ # Load image if path is provided
29
+ if isinstance(image, str):
30
+ img = Image.open(image)
31
+ else:
32
+ img = image
33
+
34
+ # Create figure
35
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
36
+
37
+ # Plot image
38
+ ax1.imshow(img)
39
+ ax1.set_title("X-ray Image")
40
+ ax1.axis("off")
41
+
42
+ # Plot predictions
43
+ if predictions:
44
+ # Sort predictions by probability
45
+ sorted_pred = sorted(predictions, key=lambda x: x[1], reverse=True)
46
+
47
+ # Get top 5 predictions
48
+ top_n = min(5, len(sorted_pred))
49
+ labels = [pred[0] for pred in sorted_pred[:top_n]]
50
+ probs = [pred[1] for pred in sorted_pred[:top_n]]
51
+
52
+ # Plot horizontal bar chart
53
+ y_pos = np.arange(top_n)
54
+ ax2.barh(y_pos, probs, align="center")
55
+ ax2.set_yticks(y_pos)
56
+ ax2.set_yticklabels(labels)
57
+ ax2.set_xlabel("Probability")
58
+ ax2.set_title("Top Predictions")
59
+ ax2.set_xlim(0, 1)
60
+
61
+ # Annotate probabilities
62
+ for i, prob in enumerate(probs):
63
+ ax2.text(prob + 0.02, i, f"{prob:.1%}", va="center")
64
+
65
+ # Set overall title
66
+ if title:
67
+ fig.suptitle(title, fontsize=16)
68
+
69
+ fig.tight_layout()
70
+ return fig
71
+
72
+ except Exception as e:
73
+ logger.error(f"Error plotting image prediction: {e}")
74
+ # Create empty figure if error occurs
75
+ fig, ax = plt.subplots(figsize=(8, 6))
76
+ ax.text(0.5, 0.5, f"Error: {str(e)}", ha="center", va="center")
77
+ return fig
78
+
79
+
80
+ def create_heatmap_overlay(image, heatmap, alpha=0.4):
81
+ """
82
+ Create a heatmap overlay on an X-ray image to highlight areas of interest.
83
+
84
+ Args:
85
+ image (PIL.Image or str): Image or path to image
86
+ heatmap (numpy.ndarray): Heatmap array
87
+ alpha (float): Transparency of the overlay
88
+
89
+ Returns:
90
+ PIL.Image: Image with heatmap overlay
91
+ """
92
+ try:
93
+ # Load image if path is provided
94
+ if isinstance(image, str):
95
+ img = cv2.imread(image)
96
+ if img is None:
97
+ raise ValueError(f"Could not load image: {image}")
98
+ elif isinstance(image, Image.Image):
99
+ img = np.array(image)
100
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
101
+ else:
102
+ img = image
103
+
104
+ # Ensure image is in BGR format for OpenCV
105
+ if len(img.shape) == 2: # Grayscale
106
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
107
+
108
+ # Resize heatmap to match image dimensions
109
+ heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
110
+
111
+ # Normalize heatmap to the 0-1 range, guarding against an all-zero map
+ heatmap = np.maximum(heatmap, 0)
+ heatmap = heatmap / (np.max(heatmap) + 1e-8)
114
+
115
+ # Apply colormap (jet) to heatmap
116
+ heatmap = np.uint8(255 * heatmap)
117
+ heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
118
+
119
+ # Create overlay
120
+ overlay = cv2.addWeighted(img, 1 - alpha, heatmap, alpha, 0)
121
+
122
+ # Convert back to PIL image
123
+ overlay = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
124
+ overlay_img = Image.fromarray(overlay)
125
+
126
+ return overlay_img
127
+
128
+ except Exception as e:
129
+ logger.error(f"Error creating heatmap overlay: {e}")
130
+ # Return original image if error occurs
131
+ if isinstance(image, str):
132
+ return Image.open(image)
133
+ elif isinstance(image, Image.Image):
134
+ return image
135
+ else:
136
+ return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
137
+
138
+
139
+ def plot_report_entities(text, entities, figsize=(12, 8)):
140
+ """
141
+ Visualize entities extracted from a medical report.
142
+
143
+ Args:
144
+ text (str): Report text
145
+ entities (dict): Dictionary of entities by category
146
+ figsize (tuple): Figure size
147
+
148
+ Returns:
149
+ matplotlib.figure.Figure: The figure object
150
+ """
151
+ try:
152
+ fig, ax = plt.subplots(figsize=figsize)
153
+ ax.axis("off")
154
+
155
+ # Set background color
156
+ fig.patch.set_facecolor("#f8f9fa")
157
+ ax.set_facecolor("#f8f9fa")
158
+
159
+ # Title
160
+ ax.text(
161
+ 0.5,
162
+ 0.98,
163
+ "Medical Report Analysis",
164
+ ha="center",
165
+ va="top",
166
+ fontsize=18,
167
+ fontweight="bold",
168
+ color="#2c3e50",
169
+ )
170
+
171
+ # Display entity counts
172
+ y_pos = 0.9
173
+ ax.text(
174
+ 0.05,
175
+ y_pos,
176
+ "Extracted Entities:",
177
+ fontsize=14,
178
+ fontweight="bold",
179
+ color="#2c3e50",
180
+ )
181
+ y_pos -= 0.05
182
+
183
+ # Define colors for different entity categories
184
+ category_colors = {
185
+ "problem": "#e74c3c", # Red
186
+ "test": "#3498db", # Blue
187
+ "treatment": "#2ecc71", # Green
188
+ "anatomy": "#9b59b6", # Purple
189
+ }
190
+
191
+ # Display entities by category
192
+ for category, items in entities.items():
193
+ if items:
194
+ y_pos -= 0.05
195
+ ax.text(
196
+ 0.1,
197
+ y_pos,
198
+ f"{category.capitalize()}:",
199
+ fontsize=12,
200
+ fontweight="bold",
201
+ )
202
+ y_pos -= 0.05
203
+ ax.text(
204
+ 0.15,
205
+ y_pos,
206
+ ", ".join(items),
207
+ wrap=True,
208
+ fontsize=11,
209
+ color=category_colors.get(category, "black"),
210
+ )
211
+
212
+ # Add the report text with highlighted entities
213
+ y_pos -= 0.1
214
+ ax.text(
215
+ 0.05,
216
+ y_pos,
217
+ "Report Text (with highlighted entities):",
218
+ fontsize=14,
219
+ fontweight="bold",
220
+ color="#2c3e50",
221
+ )
222
+ y_pos -= 0.05
223
+
224
+ # Get all entities to highlight
225
+ all_entities = []
226
+ for category, items in entities.items():
227
+ for item in items:
228
+ all_entities.append((item, category))
229
+
230
+ # Sort entities by length (longest first to avoid overlap issues)
231
+ all_entities.sort(key=lambda x: len(x[0]), reverse=True)
232
+
233
+ # Highlight entities in text
234
+ highlighted_text = text
235
+ for entity, category in all_entities:
236
+ # matplotlib draws annotation text verbatim and has no per-word
+ # color markup, so flag each entity with asterisks; the longest-first
+ # ordering above keeps nested entities from being double-marked
+ highlighted_text = highlighted_text.replace(entity, f"*{entity}*")
249
+
250
+ # Display highlighted text
251
+ ax.text(0.05, y_pos, highlighted_text, va="top", fontsize=10, wrap=True)
252
+
253
+ fig.tight_layout(rect=[0, 0.03, 1, 0.97])
254
+ return fig
255
+
256
+ except Exception as e:
257
+ logger.error(f"Error plotting report entities: {e}")
258
+ # Create empty figure if error occurs
259
+ fig, ax = plt.subplots(figsize=(8, 6))
260
+ ax.text(0.5, 0.5, f"Error: {str(e)}", ha="center", va="center")
261
+ return fig
262
+
263
+
264
+ def plot_multimodal_results(
265
+ fused_results, image=None, report_text=None, figsize=(12, 10)
266
+ ):
267
+ """
268
+ Visualize the results of multimodal analysis.
269
+
270
+ Args:
271
+ fused_results (dict): Results from multimodal fusion
272
+ image (PIL.Image or str, optional): Image or path to image
273
+ report_text (str, optional): Report text
274
+ figsize (tuple): Figure size
275
+
276
+ Returns:
277
+ matplotlib.figure.Figure: The figure object
278
+ """
279
+ try:
280
+ # Create figure with a grid layout
281
+ fig = plt.figure(figsize=figsize)
282
+ gs = fig.add_gridspec(2, 2)
283
+
284
+ # Add title
285
+ fig.suptitle(
286
+ "Multimodal Medical Analysis Results",
287
+ fontsize=18,
288
+ fontweight="bold",
289
+ y=0.98,
290
+ )
291
+
292
+ # 1. Overview panel (top left)
293
+ ax_overview = fig.add_subplot(gs[0, 0])
294
+ ax_overview.axis("off")
295
+
296
+ # Get severity info
297
+ severity = fused_results.get("severity", {})
298
+ severity_level = severity.get("level", "Unknown")
299
+ severity_score = severity.get("score", 0)
300
+
301
+ # Get primary finding
302
+ primary_finding = fused_results.get("primary_finding", "Unknown")
303
+
304
+ # Get agreement score
305
+ agreement = fused_results.get("agreement_score", 0)
306
+
307
+ # Create overview text
308
+ overview_text = [
309
+ "ANALYSIS OVERVIEW",
310
+ f"Primary Finding: {primary_finding}",
311
+ f"Severity Level: {severity_level} ({severity_score}/4)",
312
+ f"Agreement Score: {agreement:.0%}",
313
+ ]
314
+
315
+ # Define severity colors
316
+ severity_colors = {
317
+ "Normal": "#2ecc71", # Green
318
+ "Mild": "#3498db", # Blue
319
+ "Moderate": "#f39c12", # Orange
320
+ "Severe": "#e74c3c", # Red
321
+ "Critical": "#c0392b", # Dark Red
322
+ }
323
+
324
+ # Add overview text to the panel
325
+ y_pos = 0.9
326
+ ax_overview.text(
327
+ 0.5,
328
+ y_pos,
329
+ overview_text[0],
330
+ fontsize=14,
331
+ fontweight="bold",
332
+ ha="center",
333
+ va="center",
334
+ )
335
+ y_pos -= 0.15
336
+
337
+ ax_overview.text(
338
+ 0.1, y_pos, overview_text[1], fontsize=12, ha="left", va="center"
339
+ )
340
+ y_pos -= 0.1
341
+
342
+ # Severity with color
343
+ severity_color = severity_colors.get(severity_level, "black")
344
+ ax_overview.text(
345
+ 0.1, y_pos, "Severity Level:", fontsize=12, ha="left", va="center"
346
+ )
347
+ ax_overview.text(
348
+ 0.4,
349
+ y_pos,
350
+ severity_level,
351
+ fontsize=12,
352
+ color=severity_color,
353
+ fontweight="bold",
354
+ ha="left",
355
+ va="center",
356
+ )
357
+ ax_overview.text(
358
+ 0.6, y_pos, f"({severity_score}/4)", fontsize=10, ha="left", va="center"
359
+ )
360
+ y_pos -= 0.1
361
+
362
+ # Agreement score with color
363
+ agreement_color = (
364
+ "#2ecc71"
365
+ if agreement > 0.7
366
+ else "#f39c12"
367
+ if agreement > 0.4
368
+ else "#e74c3c"
369
+ )
370
+ ax_overview.text(
371
+ 0.1, y_pos, "Agreement Score:", fontsize=12, ha="left", va="center"
372
+ )
373
+ ax_overview.text(
374
+ 0.4,
375
+ y_pos,
376
+ f"{agreement:.0%}",
377
+ fontsize=12,
378
+ color=agreement_color,
379
+ fontweight="bold",
380
+ ha="left",
381
+ va="center",
382
+ )
383
+
384
+ # 2. Findings panel (top right)
385
+ ax_findings = fig.add_subplot(gs[0, 1])
386
+ ax_findings.axis("off")
387
+
388
+ # Get findings
389
+ findings = fused_results.get("findings", [])
390
+
391
+ # Add findings to the panel
392
+ y_pos = 0.9
393
+ ax_findings.text(
394
+ 0.5,
395
+ y_pos,
396
+ "KEY FINDINGS",
397
+ fontsize=14,
398
+ fontweight="bold",
399
+ ha="center",
400
+ va="center",
401
+ )
402
+ y_pos -= 0.1
403
+
404
+ if findings:
405
+ for i, finding in enumerate(findings[:5]): # Limit to 5 findings
406
+ ax_findings.text(0.05, y_pos, "•", fontsize=14, ha="left", va="center")
407
+ ax_findings.text(
408
+ 0.1, y_pos, finding, fontsize=11, ha="left", va="center", wrap=True
409
+ )
410
+ y_pos -= 0.15
411
+ else:
412
+ ax_findings.text(
413
+ 0.1,
414
+ y_pos,
415
+ "No specific findings detailed.",
416
+ fontsize=11,
417
+ ha="left",
418
+ va="center",
419
+ )
420
+
421
+ # 3. Image panel (bottom left)
422
+ ax_image = fig.add_subplot(gs[1, 0])
423
+
424
+ if image is not None:
425
+ # Load image if path is provided
426
+ if isinstance(image, str):
427
+ img = Image.open(image)
428
+ else:
429
+ img = image
430
+
431
+ # Display image
432
+ ax_image.imshow(img)
433
+ ax_image.set_title("X-ray Image", fontsize=12)
434
+ else:
435
+ ax_image.text(0.5, 0.5, "No image available", ha="center", va="center")
436
+
437
+ ax_image.axis("off")
438
+
439
+ # 4. Recommendation panel (bottom right)
440
+ ax_rec = fig.add_subplot(gs[1, 1])
441
+ ax_rec.axis("off")
442
+
443
+ # Get recommendations
444
+ recommendations = fused_results.get("followup_recommendations", [])
445
+
446
+ # Add recommendations to the panel
447
+ y_pos = 0.9
448
+ ax_rec.text(
449
+ 0.5,
450
+ y_pos,
451
+ "RECOMMENDATIONS",
452
+ fontsize=14,
453
+ fontweight="bold",
454
+ ha="center",
455
+ va="center",
456
+ )
457
+ y_pos -= 0.1
458
+
459
+ if recommendations:
460
+ for i, rec in enumerate(recommendations):
461
+ ax_rec.text(0.05, y_pos, "•", fontsize=14, ha="left", va="center")
462
+ ax_rec.text(
463
+ 0.1, y_pos, rec, fontsize=11, ha="left", va="center", wrap=True
464
+ )
465
+ y_pos -= 0.15
466
+ else:
467
+ ax_rec.text(
468
+ 0.1,
469
+ y_pos,
470
+ "No specific recommendations provided.",
471
+ fontsize=11,
472
+ ha="left",
473
+ va="center",
474
+ )
475
+
476
+ # Add disclaimer
477
+ fig.text(
478
+ 0.5,
479
+ 0.03,
480
+ "DISCLAIMER: This analysis is for informational purposes only and should not replace professional medical advice.",
481
+ fontsize=9,
482
+ style="italic",
483
+ ha="center",
484
+ )
485
+
486
+ fig.tight_layout(rect=[0, 0.05, 1, 0.95])
487
+ return fig
488
+
489
+ except Exception as e:
490
+ logger.error(f"Error plotting multimodal results: {e}")
491
+ # Create empty figure if error occurs
492
+ fig, ax = plt.subplots(figsize=(8, 6))
493
+ ax.text(0.5, 0.5, f"Error: {str(e)}", ha="center", va="center")
494
+ return fig
495
+
496
+
497
+ def figure_to_base64(fig):
498
+ """
499
+ Convert matplotlib figure to base64 string.
500
+
501
+ Args:
502
+ fig (matplotlib.figure.Figure): Figure object
503
+
504
+ Returns:
505
+ str: Base64 encoded string
506
+ """
507
+ try:
508
+ buf = io.BytesIO()
509
+ fig.savefig(buf, format="png", bbox_inches="tight")
510
+ buf.seek(0)
511
+ img_str = base64.b64encode(buf.read()).decode("utf-8")
512
+ return img_str
513
+
514
+ except Exception as e:
515
+ logger.error(f"Error converting figure to base64: {e}")
516
+ return ""