Divyanshu Tak committed
Commit f5288df · Parent(s): d85b08e
Initial commit of BrainIAC Docker application

Note: this view is limited to 50 files because the commit contains too many changes.
- .DS_Store +0 -0
- .gitattributes +9 -0
- Dockerfile +49 -0
- README.md +15 -7
- requirements.txt +23 -0
- src/.DS_Store +0 -0
- src/BrainIAC/.DS_Store +0 -0
- src/BrainIAC/Brainage/README.md +55 -0
- src/BrainIAC/Brainage/__init__.py +0 -0
- src/BrainIAC/Brainage/__pycache__/__init__.cpython-39.pyc +0 -0
- src/BrainIAC/Brainage/__pycache__/infer_brainage.cpython-39.pyc +0 -0
- src/BrainIAC/Brainage/brainage.jpeg +0 -0
- src/BrainIAC/Brainage/infer_brainage.py +85 -0
- src/BrainIAC/Brainage/train_brainage.py +230 -0
- src/BrainIAC/HD_BET/__pycache__/config.cpython-310.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/config.cpython-38.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/config.cpython-39.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/data_loading.cpython-310.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/data_loading.cpython-38.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/data_loading.cpython-39.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/hd_bet.cpython-310.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/hd_bet.cpython-38.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/network_architecture.cpython-310.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/network_architecture.cpython-38.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/network_architecture.cpython-39.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/paths.cpython-310.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/paths.cpython-38.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/paths.cpython-39.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/predict_case.cpython-310.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/predict_case.cpython-38.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/predict_case.cpython-39.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/run.cpython-310.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/run.cpython-38.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/run.cpython-39.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/utils.cpython-310.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/utils.cpython-38.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/utils.cpython-39.pyc +0 -0
- src/BrainIAC/HD_BET/config.py +121 -0
- src/BrainIAC/HD_BET/data_loading.py +121 -0
- src/BrainIAC/HD_BET/hd_bet.py +119 -0
- src/BrainIAC/HD_BET/network_architecture.py +213 -0
- src/BrainIAC/HD_BET/paths.py +6 -0
- src/BrainIAC/HD_BET/predict_case.py +126 -0
- src/BrainIAC/HD_BET/run.py +117 -0
- src/BrainIAC/HD_BET/utils.py +115 -0
- src/BrainIAC/IDHprediction/README.md +53 -0
- src/BrainIAC/IDHprediction/__init__.py +0 -0
- src/BrainIAC/IDHprediction/__pycache__/__init__.cpython-39.pyc +0 -0
- src/BrainIAC/IDHprediction/__pycache__/infer_idh.cpython-39.pyc +0 -0
- src/BrainIAC/IDHprediction/idh.jpeg +0 -0
.DS_Store
ADDED: binary file (6.15 kB)
.gitattributes
CHANGED
@@ -33,3 +33,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+src/BrainIAC/checkpoints/*.pt filter=lfs diff=lfs merge=lfs -text
+src/BrainIAC/hdbet_model/**/*.pth filter=lfs diff=lfs merge=lfs -text
+src/BrainIAC/golden_image/**/*.nii.gz filter=lfs diff=lfs merge=lfs -text
+src/BrainIAC/golden_image/**/*.nii filter=lfs diff=lfs merge=lfs -text
+src/BrainIAC/preprocessing/atlases/*.nii filter=lfs diff=lfs merge=lfs -text
+src/BrainIAC/golden_image/mni_templates/*.nii filter=lfs diff=lfs merge=lfs -text
+src/BrainIAC/hdbet_model/**/*.onnx filter=lfs diff=lfs merge=lfs -text
+src/BrainIAC/golden_image/*.nii filter=lfs diff=lfs merge=lfs -text
+/Users/divyanshutak/spaces/BrainIAC-Brainage-V0/src/BrainIAC/static/images/*.jpeg filter=lfs diff=lfs merge=lfs -text
Dockerfile
ADDED
@@ -0,0 +1,49 @@
# Use an official Python runtime as a parent image
FROM python:3.10-slim

# Set the working directory in the container
WORKDIR /app

# Install necessary system dependencies
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
    unzip \
    git \
    # Add potential ITK/build dependencies (might need adjustment)
    build-essential \
    cmake \
    libgl1 \
    libglib2.0-0 \
    # libitk5-dev # Removed: rely on pip install of itk-elastix
    && \
    rm -rf /var/lib/apt/lists/*

# Copy the requirements file first to leverage the Docker build cache
COPY requirements.txt .

# Install Python packages specified in requirements.txt
# Using --no-cache-dir reduces image size
RUN pip install --no-cache-dir -r requirements.txt

# --- HD-BET is now copied locally via src/BrainIAC ---

# Copy the rest of the application code (including local HD_BET)
COPY src/BrainIAC /app/BrainIAC

# Copy static files (like images)
COPY src/BrainIAC/static /app/BrainIAC/static

# Copy the MNI templates and parameter files
COPY src/BrainIAC/golden_image /app/BrainIAC/golden_image

# Copy the HD-BET models
COPY src/BrainIAC/hdbet_model /app/BrainIAC/hdbet_model

# Copy the model checkpoint
COPY src/BrainIAC/checkpoints/brainage_model_latest.pt /app/BrainIAC/checkpoints/brainage_model_latest.pt

# Make port 5000 available
EXPOSE 5000

# Run app.py when the container launches using gunicorn
CMD ["gunicorn", "--chdir", "/app/BrainIAC", "--bind", "0.0.0.0:5000", "--timeout", "600", "app:app"]
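The gunicorn command above assumes a module `app.py` inside `/app/BrainIAC` that exposes a WSGI callable named `app`; that file is not among the 50 files shown here. A minimal sketch of the assumed shape (the `/health` route is hypothetical, not the committed application):

```python
# Hypothetical sketch of app.py's required shape for "app:app" to resolve:
# a module-level Flask instance named `app`. The /health route is illustrative.
from flask import Flask, jsonify

app = Flask(__name__)

@app.route("/health")
def health():
    return jsonify(status="ok")

if __name__ == "__main__":
    # Mirrors the container's gunicorn bind for local debugging.
    app.run(host="0.0.0.0", port=5000)
```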
README.md
CHANGED
@@ -1,12 +1,20 @@
 ---
-title: BrainIAC
-emoji:
-colorFrom:
+title: BrainIAC Brain Age Prediction
+emoji: 🧠
+colorFrom: blue
 colorTo: green
 sdk: docker
-
-license: cc-by-nc-sa-2.0
-short_description: 'Brainage predictor '
+app_port: 5000 # must match the EXPOSE port in the Dockerfile and the gunicorn bind port
 ---
 
-
+# BrainIAC: Brain Age Prediction Demo
+
+This Hugging Face Space hosts an interactive demo for the BrainIAC model, which predicts brain age from MRI scans.
+
+**Features:**
+- Upload NIfTI (.nii.gz) or DICOM (.zip) files.
+- Optional preprocessing pipeline (registration, enhancement, skull stripping).
+- Optional generation of saliency map visualizations.
+- Predicts brain age in years from the input scan.
+
+The application runs inside a Docker container defined by the `Dockerfile`. It uses Flask, MONAI, HD-BET, and ITK/Elastix.
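For completeness, a hedged client sketch for exercising the Space. The upload route and form field name are not shown in this commit, so both are placeholders; the snippet needs the `requests` package, which is not in requirements.txt (it is a client-side tool, not part of the image):

```python
# Hypothetical client for the demo; "/predict" and the "file" field are
# assumptions, not documented endpoints of the committed app.
import requests

SPACE_URL = "http://localhost:5000"  # or the deployed Space URL

with open("subject001_T1.nii.gz", "rb") as f:
    resp = requests.post(f"{SPACE_URL}/predict", files={"file": f})
resp.raise_for_status()
print(resp.text)  # expected to report the predicted brain age in years
```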
requirements.txt
ADDED
@@ -0,0 +1,23 @@
pytorch-lightning==2.3.3
monai==1.3.2
nibabel==5.2.1
scikit-image==0.21.0
scikit-learn==1.2.2
scipy==1.10.1
seaborn==0.12.2
numpy==1.23.5
autograd==1.7.0
matplotlib==3.7.1
SimpleITK==2.4.0
tqdm
pydicom
wandb
lifelines
torch==2.6.0
opencv-python
pandas
Flask
gunicorn
PyYAML
dicom2nifti
itk-elastix
src/.DS_Store
ADDED: binary file (8.2 kB)

src/BrainIAC/.DS_Store
ADDED: binary file (12.3 kB)
src/BrainIAC/Brainage/README.md
ADDED
@@ -0,0 +1,55 @@
# Brain Age Prediction

<p align="left">
  <img src="brainage.jpeg" width="200" alt="Brain Age Prediction Example"/>
</p>

## Overview

We present the brain age prediction training and inference code for BrainIAC as a downstream task. The pipeline is trained and evaluated on T1 scans, with MAE as the evaluation metric.

## Data Requirements

- **Input**: T1-weighted MRI scans
- **Format**: NIFTI (.nii.gz)
- **Preprocessing**: Bias field corrected, registered to standard space, skull stripped
- **CSV Structure**:
  ```
  pat_id,scandate,label
  subject001,20240101,65   # brain age in years
  ```
  Refer to [quickstart.ipynb](../quickstart.ipynb) for how to preprocess the data and generate the CSV file.

## Setup

1. **Configuration**:
   Adjust the [config.yml](../config.yml) file accordingly.
   ```yaml
   # config.yml
   data:
     train_csv: "path/to/train.csv"
     val_csv: "path/to/val.csv"
     test_csv: "path/to/test.csv"
     root_dir: "../data/sample/processed"
     collate: 1   # single scan framework

   checkpoints: "./checkpoints/brainage_model.00"   # for inference/testing

   train:
     finetune: 'yes'   # yes to finetune the entire model
     freeze: 'no'      # yes to freeze the resnet backbone
     weights: ./checkpoints/brainiac.ckpt   # path to BrainIAC weights
   ```

2. **Training**:
   ```bash
   python -m Brainage.train_brainage
   ```

3. **Inference**:
   ```bash
   python -m Brainage.infer_brainage
   ```
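A minimal sketch of producing the CSV described above with pandas; the IDs, dates, and ages are placeholders:

```python
# Build the expected label CSV; column names follow the README
# (pat_id, scandate, label), with label = brain age in years.
import pandas as pd

rows = [
    {"pat_id": "subject001", "scandate": "20240101", "label": 65},
    {"pat_id": "subject002", "scandate": "20240215", "label": 42},
]
pd.DataFrame(rows).to_csv("train.csv", index=False)
```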
src/BrainIAC/Brainage/__init__.py
ADDED: empty file

src/BrainIAC/Brainage/__pycache__/__init__.cpython-39.pyc
ADDED: binary file (145 Bytes)

src/BrainIAC/Brainage/__pycache__/infer_brainage.cpython-39.pyc
ADDED: binary file (3 kB)

src/BrainIAC/Brainage/brainage.jpeg
ADDED: image file
src/BrainIAC/Brainage/infer_brainage.py
ADDED
@@ -0,0 +1,85 @@
import torch
import pandas as pd
import os
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast
from sklearn.metrics import mean_absolute_error
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from dataset2 import MedicalImageDatasetBalancedIntensity3D
from model import Backbone, SingleScanModel, Classifier
from utils import BaseConfig

class BrainAgeInference(BaseConfig):
    """
    Inference class for the brain age prediction model.
    """

    def __init__(self):
        """Initialize the inference setup with model and data."""
        super().__init__()
        self.setup_model()
        self.setup_data()

    def setup_model(self):
        config = self.get_config()
        self.backbone = Backbone()
        self.classifier = Classifier(d_model=2048)
        self.model = SingleScanModel(self.backbone, self.classifier)

        # Load weights
        checkpoint = torch.load(config["infer"]["checkpoints"], map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.model = self.model.to(self.device)
        self.model.eval()
        print("Model and checkpoint loaded!")

    ## spin up the dataloaders
    def setup_data(self):
        config = self.get_config()
        self.test_dataset = MedicalImageDatasetBalancedIntensity3D(
            csv_path=config["data"]["test_csv"],
            root_dir=config["data"]["root_dir"]
        )
        self.test_loader = DataLoader(
            self.test_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=self.custom_collate,
            num_workers=1
        )

    def infer(self):
        """Inference pass."""
        results_df = pd.DataFrame(columns=['PredictedAge', 'TrueAge'])
        all_labels = []
        all_predictions = []

        with torch.no_grad():
            for sample in tqdm(self.test_loader, desc="Inference", unit="batch"):
                inputs = sample['image'].to(self.device)
                labels = sample['label'].float().to(self.device)

                with autocast():
                    outputs = self.model(inputs)

                predictions = outputs.cpu().numpy().flatten()
                all_labels.extend(labels.cpu().numpy().flatten())
                all_predictions.extend(predictions)

                result = pd.DataFrame({
                    'PredictedAge': predictions,
                    'TrueAge': labels.cpu().numpy().flatten()
                })
                results_df = pd.concat([results_df, result], ignore_index=True)

        mae = mean_absolute_error(all_labels, all_predictions)
        # labels are brain age in years (the original print said "months",
        # which contradicts the README's CSV specification)
        print(f"Mean Absolute Error (MAE): {mae:.4f} years")
        results_df.to_csv('./data/output/brainage_output.csv', index=False)

        return mae

if __name__ == "__main__":
    inferencer = BrainAgeInference()
    mae = inferencer.infer()
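A short usage sketch for the script above: after `python -m Brainage.infer_brainage` writes its results, the MAE can be recomputed offline, assuming the default output path from the code:

```python
# Recompute MAE from the CSV written by BrainAgeInference.infer().
import pandas as pd
from sklearn.metrics import mean_absolute_error

df = pd.read_csv("./data/output/brainage_output.csv")
mae = mean_absolute_error(df["TrueAge"], df["PredictedAge"])
print(f"MAE recomputed from CSV: {mae:.4f} years")
```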
src/BrainIAC/Brainage/train_brainage.py
ADDED
@@ -0,0 +1,230 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import wandb
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.cuda.amp import GradScaler, autocast
from sklearn.metrics import mean_absolute_error
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from dataset2 import MedicalImageDatasetBalancedIntensity3D, TransformationMedicalImageDatasetBalancedIntensity3D
from model import Backbone, SingleScanModel, Classifier
from utils import BaseConfig


class BrainAgeTrainer(BaseConfig):
    """
    A trainer class for brain age prediction models.

    This class handles the complete training pipeline, including model setup,
    data loading, the training loop, and validation.
    Inherits from BaseConfig for configuration management.
    """

    def __init__(self):
        """Initialize the trainer with model, data, and training setup."""
        super().__init__()
        self.setup_wandb()
        self.setup_model()
        self.setup_data()
        self.setup_training()

    ## set up the wandb logger
    def setup_wandb(self):
        config = self.get_config()
        wandb.init(
            project=config['logger']['project_name'],
            name=config['logger']['run_name'],
            config=config
        )

    def setup_model(self):
        """
        Set up the model architecture.

        Initializes the backbone and classifier blocks and loads checkpoints.
        """
        self.backbone = Backbone()
        self.classifier = Classifier(d_model=2048)
        self.model = SingleScanModel(self.backbone, self.classifier)

        # Load BrainIAC's pretrained weights
        config = self.get_config()
        if config["train"]["finetune"] == "yes":
            checkpoint = torch.load(config["train"]["weights"], map_location=self.device)
            state_dict = checkpoint["state_dict"]
            filtered_state_dict = {}
            for key, value in state_dict.items():
                new_key = key.replace("module.", "backbone.") if key.startswith("module.") else key
                filtered_state_dict[new_key] = value
            self.model.backbone.load_state_dict(filtered_state_dict, strict=False)
            print("Pretrained weights loaded!")

        # Freeze the backbone if specified
        if config["train"]["freeze"] == "yes":
            for param in self.model.backbone.parameters():
                param.requires_grad = False
            print("Backbone weights frozen!")

        self.model = self.model.to(self.device)

    def setup_data(self):
        """
        Set up data loaders for training and validation.
        Inherits configuration from the base config.
        """
        config = self.get_config()
        self.train_dataset = TransformationMedicalImageDatasetBalancedIntensity3D(
            csv_path=config['data']['train_csv'],
            root_dir=config["data"]["root_dir"]
        )
        self.val_dataset = MedicalImageDatasetBalancedIntensity3D(
            csv_path=config['data']['val_csv'],
            root_dir=config["data"]["root_dir"]
        )

        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=config["data"]["batch_size"],
            shuffle=True,
            collate_fn=self.custom_collate,
            num_workers=config["data"]["num_workers"]
        )
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=self.custom_collate,
            num_workers=1
        )

    def setup_training(self):
        """
        Set up the training configuration: loss, scheduler, and optimizer.
        """
        config = self.get_config()
        self.criterion = nn.MSELoss()
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=config['optim']['lr'],
            weight_decay=config["optim"]["weight_decay"]
        )
        self.scheduler = CosineAnnealingWarmRestarts(self.optimizer, T_0=50, T_mult=2)
        self.scaler = GradScaler()

    def train(self):
        """
        Main training loop.
        """
        config = self.get_config()
        max_epochs = config['optim']['max_epochs']
        best_val_loss = float('inf')
        best_val_mae = float('inf')

        for epoch in range(max_epochs):
            train_loss = self.train_epoch(epoch, max_epochs)
            val_loss, mae = self.validate_epoch(epoch, max_epochs)

            # Save the best model
            if (val_loss <= best_val_loss) and (mae <= best_val_mae):
                print(f"Improved Val Loss from {best_val_loss:.4f} to {val_loss:.4f}")
                print(f"Improved Val MAE from {best_val_mae:.4f} to {mae:.4f}")
                best_val_loss = val_loss
                best_val_mae = mae
                self.save_checkpoint(epoch, val_loss, mae)

        wandb.finish()

    def train_epoch(self, epoch, max_epochs):
        """
        Training pass.

        Args:
            epoch (int): Current epoch number
            max_epochs (int): Total number of epochs

        Returns:
            float: Average training loss for the epoch
        """
        self.model.train()
        train_loss = 0.0

        for sample in tqdm(self.train_loader, desc=f"Training Epoch {epoch}/{max_epochs-1}"):
            inputs = sample['image'].to(self.device)
            labels = sample['label'].float().to(self.device)

            self.optimizer.zero_grad()
            with autocast():
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels.unsqueeze(1))

            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

            train_loss += loss.item() * inputs.size(0)

        train_loss = train_loss / len(self.train_loader.dataset)
        wandb.log({"Train Loss": train_loss})
        return train_loss

    def validate_epoch(self, epoch, max_epochs):
        """
        Validation pass.

        Args:
            epoch (int): Current epoch number
            max_epochs (int): Total number of epochs

        Returns:
            tuple: (validation_loss, mean_absolute_error)
        """
        self.model.eval()
        val_loss = 0.0
        all_labels = []
        all_preds = []

        with torch.no_grad():
            for sample in tqdm(self.val_loader, desc=f"Validation Epoch {epoch}/{max_epochs-1}"):
                inputs = sample['image'].to(self.device)
                labels = sample['label'].float().to(self.device)

                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels.unsqueeze(1))

                val_loss += loss.item() * inputs.size(0)
                all_labels.extend(labels.cpu().numpy().flatten())
                all_preds.extend(outputs.cpu().numpy().flatten())

        val_loss = val_loss / len(self.val_loader.dataset)
        mae = mean_absolute_error(all_labels, all_preds)

        wandb.log({"Val Loss": val_loss, "MAE": mae})
        # CosineAnnealingWarmRestarts.step() takes an epoch counter, not a
        # metric; the original passed val_loss here, which would be
        # misinterpreted as an epoch number.
        self.scheduler.step()

        print(f"Epoch {epoch}/{max_epochs-1} Val Loss: {val_loss:.4f} MAE: {mae:.4f}")
        return val_loss, mae

    def save_checkpoint(self, epoch, loss, mae):
        """
        Save a model checkpoint.
        """
        config = self.get_config()
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'loss': loss,
            'epoch': epoch,
        }
        save_path = os.path.join(
            config['logger']['save_dir'],
            config['logger']['save_name'].format(epoch=epoch, loss=loss, metric=mae)
        )
        torch.save(checkpoint, save_path)

if __name__ == "__main__":
    trainer = BrainAgeTrainer()
    trainer.train()
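The `setup_model` remapping above converts DataParallel-style checkpoint keys before the non-strict load. A toy demonstration with dummy tensors, independent of the real checkpoint:

```python
# Keys saved under a DataParallel-style "module." prefix are renamed to
# "backbone." before load_state_dict(..., strict=False); other keys pass through.
import torch

state_dict = {
    "module.conv1.weight": torch.zeros(1),
    "fc.bias": torch.zeros(1),  # untouched: no "module." prefix
}
filtered = {
    (k.replace("module.", "backbone.") if k.startswith("module.") else k): v
    for k, v in state_dict.items()
}
print(sorted(filtered))  # ['backbone.conv1.weight', 'fc.bias']
```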
src/BrainIAC/HD_BET/__pycache__/config.cpython-310.pyc
ADDED: binary file (4.15 kB)

src/BrainIAC/HD_BET/__pycache__/config.cpython-38.pyc
ADDED: binary file (4.13 kB)

src/BrainIAC/HD_BET/__pycache__/config.cpython-39.pyc
ADDED: binary file (4.19 kB)

src/BrainIAC/HD_BET/__pycache__/data_loading.cpython-310.pyc
ADDED: binary file (4.47 kB)

src/BrainIAC/HD_BET/__pycache__/data_loading.cpython-38.pyc
ADDED: binary file (4.48 kB)

src/BrainIAC/HD_BET/__pycache__/data_loading.cpython-39.pyc
ADDED: binary file (4.46 kB)

src/BrainIAC/HD_BET/__pycache__/hd_bet.cpython-310.pyc
ADDED: binary file (4.21 kB)

src/BrainIAC/HD_BET/__pycache__/hd_bet.cpython-38.pyc
ADDED: binary file (4.27 kB)

src/BrainIAC/HD_BET/__pycache__/network_architecture.cpython-310.pyc
ADDED: binary file (6.78 kB)

src/BrainIAC/HD_BET/__pycache__/network_architecture.cpython-38.pyc
ADDED: binary file (6.89 kB)

src/BrainIAC/HD_BET/__pycache__/network_architecture.cpython-39.pyc
ADDED: binary file (6.84 kB)

src/BrainIAC/HD_BET/__pycache__/paths.cpython-310.pyc
ADDED: binary file (324 Bytes)

src/BrainIAC/HD_BET/__pycache__/paths.cpython-38.pyc
ADDED: binary file (335 Bytes)

src/BrainIAC/HD_BET/__pycache__/paths.cpython-39.pyc
ADDED: binary file (322 Bytes)

src/BrainIAC/HD_BET/__pycache__/predict_case.cpython-310.pyc
ADDED: binary file (3.68 kB)

src/BrainIAC/HD_BET/__pycache__/predict_case.cpython-38.pyc
ADDED: binary file (3.67 kB)

src/BrainIAC/HD_BET/__pycache__/predict_case.cpython-39.pyc
ADDED: binary file (3.68 kB)

src/BrainIAC/HD_BET/__pycache__/run.cpython-310.pyc
ADDED: binary file (3.83 kB)

src/BrainIAC/HD_BET/__pycache__/run.cpython-38.pyc
ADDED: binary file (3.88 kB)

src/BrainIAC/HD_BET/__pycache__/run.cpython-39.pyc
ADDED: binary file (3.85 kB)

src/BrainIAC/HD_BET/__pycache__/utils.cpython-310.pyc
ADDED: binary file (4.68 kB)

src/BrainIAC/HD_BET/__pycache__/utils.cpython-38.pyc
ADDED: binary file (4.85 kB)

src/BrainIAC/HD_BET/__pycache__/utils.cpython-39.pyc
ADDED: binary file (4.81 kB)
src/BrainIAC/HD_BET/config.py
ADDED
@@ -0,0 +1,121 @@
import numpy as np
import torch  # needed by get_network's torch.load (missing in the original)
from HD_BET.utils import SetNetworkToVal, softmax_helper
from abc import abstractmethod
from HD_BET.network_architecture import Network


class BaseConfig(object):
    def __init__(self):
        pass

    @abstractmethod
    def get_split(self, fold, random_state=12345):
        pass

    @abstractmethod
    def get_network(self, mode="train"):
        pass

    @abstractmethod
    def get_basic_generators(self, fold):
        pass

    @abstractmethod
    def get_data_generators(self, fold):
        pass

    def preprocess(self, data):
        return data

    def __repr__(self):
        res = ""
        for v in vars(self):
            if not v.startswith("__") and not v.startswith("_") and v != 'dataset':
                res += (v + ": " + str(self.__getattribute__(v)) + "\n")
        return res


class HD_BET_Config(BaseConfig):
    def __init__(self):
        super(HD_BET_Config, self).__init__()

        self.EXPERIMENT_NAME = self.__class__.__name__  # just a generic experiment name

        # network parameters
        self.net_base_num_layers = 21
        self.BATCH_SIZE = 2
        self.net_do_DS = True
        self.net_dropout_p = 0.0
        self.net_use_inst_norm = True
        self.net_conv_use_bias = True
        self.net_norm_use_affine = True
        self.net_leaky_relu_slope = 1e-1

        # hyperparameters
        self.INPUT_PATCH_SIZE = (128, 128, 128)
        self.num_classes = 2
        self.selected_data_channels = range(1)

        # data augmentation
        self.da_mirror_axes = (2, 3, 4)

        # validation
        self.val_use_DO = False
        self.val_use_train_mode = False  # for dropout sampling
        self.val_num_repeats = 1  # only useful if dropout sampling
        self.val_batch_size = 1  # only useful if dropout sampling
        self.val_save_npz = True
        self.val_do_mirroring = True  # test-time data augmentation via mirroring
        self.val_write_images = True
        self.net_input_must_be_divisible_by = 16  # we could make a network class that has this as a property
        self.val_min_size = self.INPUT_PATCH_SIZE
        self.val_fn = None

        # CAREFUL! THIS IS A HACK TO MAKE PYTORCH 0.3 STATE DICTS COMPATIBLE WITH PYTORCH 0.4 (setting
        # keep_running_stats=True but not using them in validation. keep_running_stats was True before 0.3
        # but unused and defaults to False in 0.4)
        self.val_use_moving_averages = False

    def get_network(self, train=True, pretrained_weights=None):
        net = Network(self.num_classes, len(self.selected_data_channels), self.net_base_num_layers,
                      self.net_dropout_p, softmax_helper, self.net_leaky_relu_slope, self.net_conv_use_bias,
                      self.net_norm_use_affine, True, self.net_do_DS)

        if pretrained_weights is not None:
            net.load_state_dict(
                torch.load(pretrained_weights, map_location=lambda storage, loc: storage))

        if train:
            net.train(True)
        else:
            net.train(False)
            net.apply(SetNetworkToVal(self.val_use_DO, self.val_use_moving_averages))
            net.do_ds = False

        optimizer = None
        self.lr_scheduler = None
        return net, optimizer

    def get_data_generators(self, fold):
        pass

    def get_split(self, fold, random_state=12345):
        pass

    def get_basic_generators(self, fold):
        pass

    def on_epoch_end(self, epoch):
        pass

    def preprocess(self, data):
        data = np.copy(data)
        for c in range(data.shape[0]):
            data[c] -= data[c].mean()
            data[c] /= data[c].std()
        return data


config = HD_BET_Config
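`HD_BET_Config.preprocess` applies per-channel z-score normalization. A quick numpy check of that behavior, using random data in place of a scan:

```python
# After per-channel z-scoring, each channel has mean ~0 and std ~1.
import numpy as np

data = np.random.rand(2, 8, 8, 8) * 100 + 50  # (channels, x, y, z)
norm = data.copy()
for c in range(norm.shape[0]):
    norm[c] -= norm[c].mean()
    norm[c] /= norm[c].std()
print(norm[0].mean().round(6), norm[0].std().round(6))  # ~0.0, ~1.0
```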
src/BrainIAC/HD_BET/data_loading.py
ADDED
@@ -0,0 +1,121 @@
import SimpleITK as sitk
import numpy as np
from skimage.transform import resize


def resize_image(image, old_spacing, new_spacing, order=3):
    new_shape = (int(np.round(old_spacing[0]/new_spacing[0]*float(image.shape[0]))),
                 int(np.round(old_spacing[1]/new_spacing[1]*float(image.shape[1]))),
                 int(np.round(old_spacing[2]/new_spacing[2]*float(image.shape[2]))))
    return resize(image, new_shape, order=order, mode='edge', cval=0, anti_aliasing=False)


def preprocess_image(itk_image, is_seg=False, spacing_target=(1, 0.5, 0.5)):
    spacing = np.array(itk_image.GetSpacing())[[2, 1, 0]]
    image = sitk.GetArrayFromImage(itk_image).astype(float)

    assert len(image.shape) == 3, "The image has an unsupported number of dimensions. Only 3D images are allowed"

    if not is_seg:
        if np.any([[i != j] for i, j in zip(spacing, spacing_target)]):
            image = resize_image(image, spacing, spacing_target).astype(np.float32)

        image -= image.mean()
        image /= image.std()
    else:
        new_shape = (int(np.round(spacing[0] / spacing_target[0] * float(image.shape[0]))),
                     int(np.round(spacing[1] / spacing_target[1] * float(image.shape[1]))),
                     int(np.round(spacing[2] / spacing_target[2] * float(image.shape[2]))))
        image = resize_segmentation(image, new_shape, 1)
    return image


def load_and_preprocess(mri_file):
    images = {}
    # t1
    images["T1"] = sitk.ReadImage(mri_file)

    properties_dict = {
        "spacing": images["T1"].GetSpacing(),
        "direction": images["T1"].GetDirection(),
        "size": images["T1"].GetSize(),
        "origin": images["T1"].GetOrigin()
    }

    for k in images.keys():
        images[k] = preprocess_image(images[k], is_seg=False, spacing_target=(1.5, 1.5, 1.5))

    properties_dict['size_before_cropping'] = images["T1"].shape

    imgs = []
    for seq in ['T1']:
        imgs.append(images[seq][None])
    all_data = np.vstack(imgs)
    print("image shape after preprocessing: ", str(all_data[0].shape))
    return all_data, properties_dict


def save_segmentation_nifti(segmentation, dct, out_fname, order=1):
    '''
    segmentation must have the same spacing as the original nifti (for now). segmentation may have been cropped out
    of the original image

    dct:
        size_before_cropping
        brain_bbox
        size -> this is the original size of the dataset; if the image was not resampled, this is the same as size_before_cropping
        spacing
        origin
        direction

    :param segmentation:
    :param dct:
    :param out_fname:
    :return:
    '''
    old_size = dct.get('size_before_cropping')
    bbox = dct.get('brain_bbox')
    if bbox is not None:
        seg_old_size = np.zeros(old_size)
        for c in range(3):
            bbox[c][1] = np.min((bbox[c][0] + segmentation.shape[c], old_size[c]))
        seg_old_size[bbox[0][0]:bbox[0][1],
                     bbox[1][0]:bbox[1][1],
                     bbox[2][0]:bbox[2][1]] = segmentation
    else:
        seg_old_size = segmentation
    # compare array shapes, not voxel values (the transcription dropped .shape here)
    if np.any(np.array(seg_old_size.shape) != np.array(dct['size'])[[2, 1, 0]]):
        seg_old_spacing = resize_segmentation(seg_old_size, np.array(dct['size'])[[2, 1, 0]], order=order)
    else:
        seg_old_spacing = seg_old_size
    seg_resized_itk = sitk.GetImageFromArray(seg_old_spacing.astype(np.int32))
    seg_resized_itk.SetSpacing(np.array(dct['spacing'])[[0, 1, 2]])
    seg_resized_itk.SetOrigin(dct['origin'])
    seg_resized_itk.SetDirection(dct['direction'])
    sitk.WriteImage(seg_resized_itk, out_fname)


def resize_segmentation(segmentation, new_shape, order=3, cval=0):
    '''
    Taken from batchgenerators (https://github.com/MIC-DKFZ/batchgenerators) to prevent a dependency.

    Resizes a segmentation map. Supports all orders (see skimage documentation). Will transform the segmentation
    map to a one-hot encoding, which is resized and transformed back to a segmentation map.
    This prevents interpolation artifacts ([0, 0, 2] -> [0, 1, 2])
    :param segmentation:
    :param new_shape:
    :param order:
    :return:
    '''
    tpe = segmentation.dtype
    unique_labels = np.unique(segmentation)
    assert len(segmentation.shape) == len(new_shape), "new shape must have the same dimensionality as the segmentation"
    if order == 0:
        return resize(segmentation, new_shape, order, mode="constant", cval=cval, clip=True, anti_aliasing=False).astype(tpe)
    else:
        reshaped = np.zeros(new_shape, dtype=segmentation.dtype)

        for i, c in enumerate(unique_labels):
            reshaped_multihot = resize((segmentation == c).astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False)
            reshaped[reshaped_multihot >= 0.5] = c
        return reshaped
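A small sanity check of `resize_segmentation`, assuming the `HD_BET` package is importable: interpolating integer labels directly can invent intermediate values, which the per-label one-hot resize avoids:

```python
# Resize a label map containing only {0, 2}; no spurious label 1 appears.
import numpy as np
from HD_BET.data_loading import resize_segmentation

seg = np.zeros((4, 4, 4), dtype=np.int32)
seg[2:, :, :] = 2  # labels are only {0, 2}
resized = resize_segmentation(seg, (8, 8, 8), order=3)
print(np.unique(resized))  # [0 2]: no interpolated label 1
```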
src/BrainIAC/HD_BET/hd_bet.py
ADDED
@@ -0,0 +1,119 @@
#!/usr/bin/env python

import os
import sys
sys.path.append("/mnt/93E8-0534/AIDAN/HDBET/")  # legacy hardcoded path from the original author
from HD_BET.run import run_hd_bet
from HD_BET.utils import maybe_mkdir_p, subfiles
import HD_BET


def hd_bet(input_file_or_dir, output_file_or_dir, mode, device, tta, pp=1, save_mask=0, overwrite_existing=1):

    if output_file_or_dir is None:
        output_file_or_dir = os.path.join(os.path.dirname(input_file_or_dir),
                                          os.path.basename(input_file_or_dir).split(".")[0] + "_bet")

    params_file = os.path.join(HD_BET.__path__[0], "model_final.py")
    config_file = os.path.join(HD_BET.__path__[0], "config.py")

    assert os.path.abspath(input_file_or_dir) != os.path.abspath(output_file_or_dir), "output must be different from input"

    if device == 'cpu':
        pass
    else:
        device = int(device)

    if os.path.isdir(input_file_or_dir):
        maybe_mkdir_p(output_file_or_dir)
        input_files = subfiles(input_file_or_dir, suffix='_0000.nii.gz', join=False)

        if len(input_files) == 0:
            raise RuntimeError("input is a folder but no nifti files (.nii.gz) were found in it")

        output_files = [os.path.join(output_file_or_dir, i) for i in input_files]
        input_files = [os.path.join(input_file_or_dir, i) for i in input_files]
    else:
        if not output_file_or_dir.endswith('.nii.gz'):
            output_file_or_dir += '.nii.gz'
            assert os.path.abspath(input_file_or_dir) != os.path.abspath(output_file_or_dir), "output must be different from input"

        output_files = [output_file_or_dir]
        input_files = [input_file_or_dir]

    if tta == 0:
        tta = False
    elif tta == 1:
        tta = True
    else:
        raise ValueError("Unknown value for tta: %s. Expected: 0 or 1" % str(tta))

    if overwrite_existing == 0:
        overwrite_existing = False
    elif overwrite_existing == 1:
        overwrite_existing = True
    else:
        raise ValueError("Unknown value for overwrite_existing: %s. Expected: 0 or 1" % str(overwrite_existing))

    if pp == 0:
        pp = False
    elif pp == 1:
        pp = True
    else:
        raise ValueError("Unknown value for pp: %s. Expected: 0 or 1" % str(pp))

    if save_mask == 0:
        save_mask = False
    elif save_mask == 1:
        save_mask = True
    else:
        # the original message referred to pp here; fixed to name save_mask
        raise ValueError("Unknown value for save_mask: %s. Expected: 0 or 1" % str(save_mask))

    run_hd_bet(input_files, output_files, mode, config_file, device, pp, tta, save_mask, overwrite_existing)


if __name__ == "__main__":
    print("\n########################")
    print("If you are using hd-bet, please cite the following paper:")
    print("Isensee F, Schell M, Tursunova I, Brugnara G, Bonekamp D, Neuberger U, Wick A, Schlemmer HP, Heiland S, Wick W, "
          "Bendszus M, Maier-Hein KH, Kickingereder P. Automated brain extraction of multi-sequence MRI using artificial "
          "neural networks. arXiv preprint arXiv:1901.11341, 2019.")
    print("########################\n")

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', type=str, required=True,
                        help='input. Can be either a single file name or an input folder. If file: must be '
                             'nifti (.nii.gz) and can only be 3D. No support for 4D images; use fslsplit to '
                             'split 4D sequences into 3D images. If folder: all files ending with .nii.gz '
                             'within that folder will be brain extracted.')
    parser.add_argument('-o', '--output', type=str, required=False,
                        help='output. Can be either a filename or a folder. If it does not exist, the folder '
                             'will be created')
    parser.add_argument('-mode', type=str, default='accurate', required=False,
                        help='can be either \'fast\' or \'accurate\'. Fast will use only one set of parameters, '
                             'whereas accurate will use the five sets of parameters that resulted from our '
                             'cross-validation as an ensemble. Default: accurate')
    parser.add_argument('-device', default='0', type=str, required=False,
                        help='used to set on which device the prediction will run. Must be either int or str. '
                             'Use int for GPU id or \'cpu\' to run on CPU. When using CPU you should consider '
                             'disabling tta. Default for -device is: 0')
    parser.add_argument('-tta', default=1, type=int, required=False,
                        help='whether to use test-time data augmentation (mirroring). 1=True, 0=False. Disable '
                             'this if you are using CPU to speed things up! Default: 1')
    parser.add_argument('-pp', default=1, type=int, required=False,
                        help='set to 0 to disable postprocessing (removing all but the largest connected '
                             'component in the prediction). Default: 1')
    parser.add_argument('-s', '--save_mask', default=1, type=int, required=False,
                        help='if set to 0 the segmentation mask will not be saved')
    parser.add_argument('--overwrite_existing', default=1, type=int, required=False,
                        help="set this to 0 if you don't want to overwrite existing predictions")

    args = parser.parse_args()

    hd_bet(args.input, args.output, args.mode, args.device, args.tta, args.pp, args.save_mask, args.overwrite_existing)
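A sketch of calling `hd_bet()` programmatically rather than via the CLI; the paths are placeholders, and the model parameter files must already exist under `folder_with_parameter_files` (see paths.py below):

```python
# Programmatic skull stripping with the function above; mirrors the argparse
# defaults except for CPU execution with tta disabled for speed.
from HD_BET.hd_bet import hd_bet

hd_bet(
    "/data/subject001_T1.nii.gz",       # input NIfTI (placeholder path)
    "/data/subject001_T1_bet.nii.gz",   # skull-stripped output (placeholder)
    mode="accurate",
    device="cpu",   # or a GPU id string like "0"
    tta=0,          # disable test-time mirroring on CPU
)
```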
src/BrainIAC/HD_BET/network_architecture.py
ADDED
@@ -0,0 +1,213 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from HD_BET.utils import softmax_helper


class EncodingModule(nn.Module):
    def __init__(self, in_channels, out_channels, filter_size=3, dropout_p=0.3, leakiness=1e-2, conv_bias=True,
                 inst_norm_affine=True, lrelu_inplace=True):
        nn.Module.__init__(self)
        self.dropout_p = dropout_p
        self.lrelu_inplace = lrelu_inplace
        self.inst_norm_affine = inst_norm_affine
        self.conv_bias = conv_bias
        self.leakiness = leakiness
        self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
        self.conv1 = nn.Conv3d(in_channels, out_channels, filter_size, 1, (filter_size - 1) // 2, bias=self.conv_bias)
        self.dropout = nn.Dropout3d(dropout_p)
        self.bn_2 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
        self.conv2 = nn.Conv3d(out_channels, out_channels, filter_size, 1, (filter_size - 1) // 2, bias=self.conv_bias)

    def forward(self, x):
        skip = x
        x = F.leaky_relu(self.bn_1(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
        x = self.conv1(x)
        if self.dropout_p is not None and self.dropout_p > 0:
            x = self.dropout(x)
        x = F.leaky_relu(self.bn_2(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
        x = self.conv2(x)
        x = x + skip
        return x


class Upsample(nn.Module):
    def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=True):
        super(Upsample, self).__init__()
        self.align_corners = align_corners
        self.mode = mode
        self.scale_factor = scale_factor
        self.size = size

    def forward(self, x):
        return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode,
                                         align_corners=self.align_corners)


class LocalizationModule(nn.Module):
    def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
                 lrelu_inplace=True):
        nn.Module.__init__(self)
        self.lrelu_inplace = lrelu_inplace
        self.inst_norm_affine = inst_norm_affine
        self.conv_bias = conv_bias
        self.leakiness = leakiness
        self.conv1 = nn.Conv3d(in_channels, in_channels, 3, 1, 1, bias=self.conv_bias)
        self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
        self.conv2 = nn.Conv3d(in_channels, out_channels, 1, 1, 0, bias=self.conv_bias)
        self.bn_2 = nn.InstanceNorm3d(out_channels, affine=self.inst_norm_affine, track_running_stats=True)

    def forward(self, x):
        x = F.leaky_relu(self.bn_1(self.conv1(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
        x = F.leaky_relu(self.bn_2(self.conv2(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
        return x


class UpsamplingModule(nn.Module):
    def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
                 lrelu_inplace=True):
        nn.Module.__init__(self)
        self.lrelu_inplace = lrelu_inplace
        self.inst_norm_affine = inst_norm_affine
        self.conv_bias = conv_bias
        self.leakiness = leakiness
        self.upsample = Upsample(scale_factor=2, mode="trilinear", align_corners=True)
        self.upsample_conv = nn.Conv3d(in_channels, out_channels, 3, 1, 1, bias=self.conv_bias)
        self.bn = nn.InstanceNorm3d(out_channels, affine=self.inst_norm_affine, track_running_stats=True)

    def forward(self, x):
        x = F.leaky_relu(self.bn(self.upsample_conv(self.upsample(x))), negative_slope=self.leakiness,
                         inplace=self.lrelu_inplace)
        return x


class DownsamplingModule(nn.Module):
    def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
                 lrelu_inplace=True):
        nn.Module.__init__(self)
        self.lrelu_inplace = lrelu_inplace
        self.inst_norm_affine = inst_norm_affine
        self.conv_bias = conv_bias
        self.leakiness = leakiness
        self.bn = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
        self.downsample = nn.Conv3d(in_channels, out_channels, 3, 2, 1, bias=self.conv_bias)

    def forward(self, x):
        x = F.leaky_relu(self.bn(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
        b = self.downsample(x)
        return x, b


class Network(nn.Module):
    def __init__(self, num_classes=4, num_input_channels=4, base_filters=16, dropout_p=0.3,
                 final_nonlin=softmax_helper, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
                 lrelu_inplace=True, do_ds=True):
        super(Network, self).__init__()

        self.do_ds = do_ds
        self.lrelu_inplace = lrelu_inplace
        self.inst_norm_affine = inst_norm_affine
        self.conv_bias = conv_bias
        self.leakiness = leakiness
        self.final_nonlin = final_nonlin
        self.init_conv = nn.Conv3d(num_input_channels, base_filters, 3, 1, 1, bias=self.conv_bias)

        self.context1 = EncodingModule(base_filters, base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
                                       inst_norm_affine=True, lrelu_inplace=True)
        self.down1 = DownsamplingModule(base_filters, base_filters * 2, leakiness=1e-2, conv_bias=True,
                                        inst_norm_affine=True, lrelu_inplace=True)

        self.context2 = EncodingModule(2 * base_filters, 2 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
                                       inst_norm_affine=True, lrelu_inplace=True)
        self.down2 = DownsamplingModule(2 * base_filters, base_filters * 4, leakiness=1e-2, conv_bias=True,
                                        inst_norm_affine=True, lrelu_inplace=True)

        self.context3 = EncodingModule(4 * base_filters, 4 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
                                       inst_norm_affine=True, lrelu_inplace=True)
        self.down3 = DownsamplingModule(4 * base_filters, base_filters * 8, leakiness=1e-2, conv_bias=True,
                                        inst_norm_affine=True, lrelu_inplace=True)

        self.context4 = EncodingModule(8 * base_filters, 8 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
                                       inst_norm_affine=True, lrelu_inplace=True)
        self.down4 = DownsamplingModule(8 * base_filters, base_filters * 16, leakiness=1e-2, conv_bias=True,
                                        inst_norm_affine=True, lrelu_inplace=True)

        self.context5 = EncodingModule(16 * base_filters, 16 * base_filters, 3, dropout_p, leakiness=1e-2,
                                       conv_bias=True, inst_norm_affine=True, lrelu_inplace=True)

        self.bn_after_context5 = nn.InstanceNorm3d(16 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
        self.up1 = UpsamplingModule(16 * base_filters, 8 * base_filters, leakiness=1e-2, conv_bias=True,
                                    inst_norm_affine=True, lrelu_inplace=True)

        self.loc1 = LocalizationModule(16 * base_filters, 8 * base_filters, leakiness=1e-2, conv_bias=True,
                                       inst_norm_affine=True, lrelu_inplace=True)
        self.up2 = UpsamplingModule(8 * base_filters, 4 * base_filters, leakiness=1e-2, conv_bias=True,
                                    inst_norm_affine=True, lrelu_inplace=True)

        self.loc2 = LocalizationModule(8 * base_filters, 4 * base_filters, leakiness=1e-2, conv_bias=True,
                                       inst_norm_affine=True, lrelu_inplace=True)
        self.loc2_seg = nn.Conv3d(4 * base_filters, num_classes, 1, 1, 0, bias=False)
        self.up3 = UpsamplingModule(4 * base_filters, 2 * base_filters, leakiness=1e-2, conv_bias=True,
                                    inst_norm_affine=True, lrelu_inplace=True)

        self.loc3 = LocalizationModule(4 * base_filters, 2 * base_filters, leakiness=1e-2, conv_bias=True,
                                       inst_norm_affine=True, lrelu_inplace=True)
        self.loc3_seg = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False)
        self.up4 = UpsamplingModule(2 * base_filters, 1 * base_filters, leakiness=1e-2, conv_bias=True,
                                    inst_norm_affine=True, lrelu_inplace=True)

        self.end_conv_1 = nn.Conv3d(2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias)
        self.end_conv_1_bn = nn.InstanceNorm3d(2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
        self.end_conv_2 = nn.Conv3d(2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias)
        self.end_conv_2_bn = nn.InstanceNorm3d(2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
        self.seg_layer = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False)

    def forward(self, x):
        seg_outputs = []

        x = self.init_conv(x)
        x = self.context1(x)

        skip1, x = self.down1(x)
        x = self.context2(x)

        skip2, x = self.down2(x)
        x = self.context3(x)

        skip3, x = self.down3(x)
        x = self.context4(x)

        skip4, x = self.down4(x)
        x = self.context5(x)

        x = F.leaky_relu(self.bn_after_context5(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
        x = self.up1(x)

        x = torch.cat((skip4, x), dim=1)
        x = self.loc1(x)
        x = self.up2(x)

        x = torch.cat((skip3, x), dim=1)
        x = self.loc2(x)
        loc2_seg = self.final_nonlin(self.loc2_seg(x))
        seg_outputs.append(loc2_seg)
        x = self.up3(x)

        x = torch.cat((skip2, x), dim=1)
        x = self.loc3(x)
        loc3_seg = self.final_nonlin(self.loc3_seg(x))
        seg_outputs.append(loc3_seg)
        x = self.up4(x)

        x = torch.cat((skip1, x), dim=1)
        x = F.leaky_relu(self.end_conv_1_bn(self.end_conv_1(x)), negative_slope=self.leakiness,
                         inplace=self.lrelu_inplace)
        x = F.leaky_relu(self.end_conv_2_bn(self.end_conv_2(x)), negative_slope=self.leakiness,
                         inplace=self.lrelu_inplace)
        x = self.final_nonlin(self.seg_layer(x))
        seg_outputs.append(x)

        if self.do_ds:
            return seg_outputs[::-1]
        else:
            return seg_outputs[-1]
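A smoke test for `Network`, assuming `HD_BET` is importable. Spatial dimensions must be divisible by 16 (four downsampling stages), and HD-BET itself uses two classes and one input channel (see config.py above):

```python
# Dummy forward pass through the U-Net-style architecture defined above.
import torch
from HD_BET.network_architecture import Network

net = Network(num_classes=2, num_input_channels=1, base_filters=16, do_ds=False)
x = torch.zeros(1, 1, 32, 32, 32)  # (batch, channels, x, y, z); 32 is divisible by 16
with torch.no_grad():
    out = net(x)  # do_ds=False returns only the full-resolution head
print(out.shape)  # torch.Size([1, 2, 32, 32, 32])
```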
src/BrainIAC/HD_BET/paths.py
ADDED
@@ -0,0 +1,6 @@
import os

# please refer to the readme on where to get the parameters. Save them in this folder:
# Original path: "/media/sdb/divyanshu/divyanshu/aidan_segmentation/nnUNet_pLGG/home/divyanshu/hd-bet_params"
# Updated path for the Docker container:
folder_with_parameter_files = "/app/BrainIAC/hdbet_model"
src/BrainIAC/HD_BET/predict_case.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
import numpy as np


def pad_patient_3D(patient, shape_must_be_divisible_by=16, min_size=None):
    if not (isinstance(shape_must_be_divisible_by, list) or isinstance(shape_must_be_divisible_by, tuple)):
        shape_must_be_divisible_by = [shape_must_be_divisible_by] * 3
    shp = patient.shape
    new_shp = [shp[0] + shape_must_be_divisible_by[0] - shp[0] % shape_must_be_divisible_by[0],
               shp[1] + shape_must_be_divisible_by[1] - shp[1] % shape_must_be_divisible_by[1],
               shp[2] + shape_must_be_divisible_by[2] - shp[2] % shape_must_be_divisible_by[2]]
    for i in range(len(shp)):
        if shp[i] % shape_must_be_divisible_by[i] == 0:
            new_shp[i] -= shape_must_be_divisible_by[i]
    if min_size is not None:
        new_shp = np.max(np.vstack((np.array(new_shp), np.array(min_size))), 0)
    return reshape_by_padding_upper_coords(patient, new_shp, 0), shp


def reshape_by_padding_upper_coords(image, new_shape, pad_value=None):
    shape = tuple(list(image.shape))
    new_shape = tuple(np.max(np.concatenate((shape, new_shape)).reshape((2, len(shape))), axis=0))
    if pad_value is None:
        if len(shape) == 2:
            pad_value = image[0, 0]
        elif len(shape) == 3:
            pad_value = image[0, 0, 0]
        else:
            raise ValueError("Image must be either 2 or 3 dimensional")
    res = np.ones(list(new_shape), dtype=image.dtype) * pad_value
    if len(shape) == 2:
        res[0:0 + int(shape[0]), 0:0 + int(shape[1])] = image
    elif len(shape) == 3:
        res[0:0 + int(shape[0]), 0:0 + int(shape[1]), 0:0 + int(shape[2])] = image
    return res


def predict_case_3D_net(net, patient_data, do_mirroring, num_repeats, BATCH_SIZE=None,
                        new_shape_must_be_divisible_by=16, min_size=None, main_device=0, mirror_axes=(2, 3, 4)):
    with torch.no_grad():
        pad_res = []
        for i in range(patient_data.shape[0]):
            t, old_shape = pad_patient_3D(patient_data[i], new_shape_must_be_divisible_by, min_size)
            pad_res.append(t[None])

        patient_data = np.vstack(pad_res)
        new_shp = patient_data.shape

        data = np.zeros(tuple([1] + list(new_shp)), dtype=np.float32)
        data[0] = patient_data

        if BATCH_SIZE is not None:
            data = np.vstack([data] * BATCH_SIZE)

        a = torch.rand(data.shape).float()

        if main_device != 'cpu':
            a = a.cuda(main_device)

        # 8 mirror configurations (identity plus all axis flips) for test-time augmentation
        if do_mirroring:
            x = 8
        else:
            x = 1
        all_preds = []
        for i in range(num_repeats):
            for m in range(x):
                # np.copy is necessary because ::-1 only creates a view
                data_for_net = np.array(data)
                do_stuff = False
                if m == 0:
                    do_stuff = True
                if m == 1 and (4 in mirror_axes):
                    do_stuff = True
                    data_for_net = data_for_net[:, :, :, :, ::-1]
                if m == 2 and (3 in mirror_axes):
                    do_stuff = True
                    data_for_net = data_for_net[:, :, :, ::-1, :]
                if m == 3 and (4 in mirror_axes) and (3 in mirror_axes):
                    do_stuff = True
                    data_for_net = data_for_net[:, :, :, ::-1, ::-1]
                if m == 4 and (2 in mirror_axes):
                    do_stuff = True
                    data_for_net = data_for_net[:, :, ::-1, :, :]
                if m == 5 and (2 in mirror_axes) and (4 in mirror_axes):
                    do_stuff = True
                    data_for_net = data_for_net[:, :, ::-1, :, ::-1]
                if m == 6 and (2 in mirror_axes) and (3 in mirror_axes):
                    do_stuff = True
                    data_for_net = data_for_net[:, :, ::-1, ::-1, :]
                if m == 7 and (2 in mirror_axes) and (3 in mirror_axes) and (4 in mirror_axes):
                    do_stuff = True
                    data_for_net = data_for_net[:, :, ::-1, ::-1, ::-1]

                if do_stuff:
                    _ = a.data.copy_(torch.from_numpy(np.copy(data_for_net)))
                    p = net(a)
                    p = p.data.cpu().numpy()

                    # undo the mirroring so all predictions are aligned
                    if m == 1 and (4 in mirror_axes):
                        p = p[:, :, :, :, ::-1]
                    if m == 2 and (3 in mirror_axes):
                        p = p[:, :, :, ::-1, :]
                    if m == 3 and (4 in mirror_axes) and (3 in mirror_axes):
                        p = p[:, :, :, ::-1, ::-1]
                    if m == 4 and (2 in mirror_axes):
                        p = p[:, :, ::-1, :, :]
                    if m == 5 and (2 in mirror_axes) and (4 in mirror_axes):
                        p = p[:, :, ::-1, :, ::-1]
                    if m == 6 and (2 in mirror_axes) and (3 in mirror_axes):
                        p = p[:, :, ::-1, ::-1, :]
                    if m == 7 and (2 in mirror_axes) and (3 in mirror_axes) and (4 in mirror_axes):
                        p = p[:, :, ::-1, ::-1, ::-1]
                    all_preds.append(p)

        # crop the padding away and aggregate over repeats/mirrors
        stacked = np.vstack(all_preds)[:, :, :old_shape[0], :old_shape[1], :old_shape[2]]
        predicted_segmentation = stacked.mean(0).argmax(0)
        uncertainty = stacked.var(0)
        bayesian_predictions = stacked
        softmax_pred = stacked.mean(0)
        return predicted_segmentation, bayesian_predictions, softmax_pred, uncertainty
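`predict_case_3D_net` pads each modality to a network-friendly shape, runs up to eight mirrored copies of the volume as test-time augmentation, un-mirrors and averages the predictions, and crops back to the original shape. A toy shape-check sketch (a trivial conv stands in for the real network, which would additionally apply a softmax):

```python
import numpy as np
import torch.nn as nn
from HD_BET.predict_case import predict_case_3D_net

toy_net = nn.Conv3d(1, 2, kernel_size=3, padding=1)        # stand-in "net"
volume = np.random.rand(1, 24, 32, 40).astype(np.float32)  # (modalities, x, y, z)

seg, bayesian, softmax_pred, uncertainty = predict_case_3D_net(
    toy_net, volume, do_mirroring=True, num_repeats=1,
    main_device='cpu', mirror_axes=(2, 3, 4))

print(seg.shape)  # (24, 32, 40): the padding is cropped away again
```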
src/BrainIAC/HD_BET/run.py
ADDED
@@ -0,0 +1,117 @@
import torch
import numpy as np
import SimpleITK as sitk
from HD_BET.data_loading import load_and_preprocess, save_segmentation_nifti
from HD_BET.predict_case import predict_case_3D_net
import imp  # legacy loader; still available on the Python 3.9 interpreter this image uses
from HD_BET.utils import postprocess_prediction, SetNetworkToVal, get_params_fname, maybe_download_parameters
import os
import HD_BET


def apply_bet(img, bet, out_fname):
    # zero out every voxel outside the predicted brain mask
    img_itk = sitk.ReadImage(img)
    img_npy = sitk.GetArrayFromImage(img_itk)
    img_bet = sitk.GetArrayFromImage(sitk.ReadImage(bet))
    img_npy[img_bet == 0] = 0
    out = sitk.GetImageFromArray(img_npy)
    out.CopyInformation(img_itk)
    sitk.WriteImage(out, out_fname)


def run_hd_bet(mri_fnames, output_fnames, mode="accurate", config_file=os.path.join(HD_BET.__path__[0], "config.py"),
               device=0, postprocess=False, do_tta=True, keep_mask=True, overwrite=True):
    """
    :param mri_fnames: str or list/tuple of str
    :param output_fnames: str or list/tuple of str. If list: must have the same length as mri_fnames
    :param mode: fast or accurate
    :param config_file: config.py
    :param device: either int (for device id) or 'cpu'
    :param postprocess: whether to do postprocessing or not. Postprocessing here consists of simply discarding all
    but the largest predicted connected component. Default: False
    :param do_tta: whether to do test-time data augmentation by mirroring along all axes. Default: True. If you use
    CPU you may want to turn that off to speed things up
    :return:
    """

    list_of_param_files = []

    if mode == 'fast':
        params_file = get_params_fname(0)
        maybe_download_parameters(0)
        list_of_param_files.append(params_file)
    elif mode == 'accurate':
        for i in range(5):
            params_file = get_params_fname(i)
            maybe_download_parameters(i)
            list_of_param_files.append(params_file)
    else:
        raise ValueError("Unknown value for mode: %s. Expected: fast or accurate" % mode)

    assert all([os.path.isfile(i) for i in list_of_param_files]), "Could not find parameter files"

    cf = imp.load_source('cf', config_file)
    cf = cf.config()

    net, _ = cf.get_network(cf.val_use_train_mode, None)
    if device == "cpu":
        net = net.cpu()
    else:
        net.cuda(device)

    if not isinstance(mri_fnames, (list, tuple)):
        mri_fnames = [mri_fnames]

    if not isinstance(output_fnames, (list, tuple)):
        output_fnames = [output_fnames]

    assert len(mri_fnames) == len(output_fnames), "mri_fnames and output_fnames must have the same length"

    params = []
    for p in list_of_param_files:
        params.append(torch.load(p, map_location=lambda storage, loc: storage))

    for in_fname, out_fname in zip(mri_fnames, output_fnames):
        mask_fname = out_fname[:-7] + "_mask.nii.gz"
        if overwrite or (not (os.path.isfile(mask_fname) and keep_mask) or not os.path.isfile(out_fname)):
            print("File:", in_fname)
            print("preprocessing...")
            try:
                data, data_dict = load_and_preprocess(in_fname)
            except RuntimeError:
                print("\nERROR\nCould not read file", in_fname, "\n")
                continue
            except AssertionError as e:
                print(e)
                continue

            softmax_preds = []

            print("prediction (CNN id)...")
            for i, p in enumerate(params):
                print(i)
                net.load_state_dict(p)
                net.eval()
                net.apply(SetNetworkToVal(False, False))
                _, _, softmax_pred, _ = predict_case_3D_net(net, data, do_tta, cf.val_num_repeats,
                                                            cf.val_batch_size, cf.net_input_must_be_divisible_by,
                                                            cf.val_min_size, device, cf.da_mirror_axes)
                softmax_preds.append(softmax_pred[None])

            # ensemble: average the softmax outputs of all folds, then argmax
            seg = np.argmax(np.vstack(softmax_preds).mean(0), 0)

            if postprocess:
                seg = postprocess_prediction(seg)

            print("exporting segmentation...")
            save_segmentation_nifti(seg, data_dict, mask_fname)

            apply_bet(in_fname, mask_fname, out_fname)

            if not keep_mask:
                os.remove(mask_fname)
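A minimal usage sketch for `run_hd_bet` (the file paths are assumptions): brain-extract a single T1-weighted volume on CPU, with test-time augmentation disabled for speed. The skull-stripped image is written to the output path and the mask next to it with a `_mask.nii.gz` suffix:

```python
from HD_BET.run import run_hd_bet

# writes /data/t1_bet.nii.gz and /data/t1_bet_mask.nii.gz
run_hd_bet("/data/t1.nii.gz", "/data/t1_bet.nii.gz",
           mode="fast", device="cpu", do_tta=False)
```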
src/BrainIAC/HD_BET/utils.py
ADDED
@@ -0,0 +1,115 @@
from urllib.request import urlopen
import torch
from torch import nn
import numpy as np
from skimage.morphology import label
import os
from HD_BET.paths import folder_with_parameter_files


def get_params_fname(fold):
    return os.path.join(folder_with_parameter_files, "%d.model" % fold)


def maybe_download_parameters(fold=0, force_overwrite=False):
    """
    Downloads the parameters for some fold if it is not present yet.
    :param fold:
    :param force_overwrite: if True the old parameter file will be deleted (if present) prior to download
    :return:
    """

    assert 0 <= fold <= 4, "fold must be between 0 and 4"

    if not os.path.isdir(folder_with_parameter_files):
        maybe_mkdir_p(folder_with_parameter_files)

    out_filename = get_params_fname(fold)

    if force_overwrite and os.path.isfile(out_filename):
        os.remove(out_filename)

    if not os.path.isfile(out_filename):
        url = "https://zenodo.org/record/2540695/files/%d.model?download=1" % fold
        print("Downloading", url, "...")
        data = urlopen(url).read()
        with open(out_filename, 'wb') as f:
            f.write(data)


def init_weights(module):
    # the in-place initializers (kaiming_normal_, constant_) replace the
    # deprecated kaiming_normal/constant assignments used originally
    if isinstance(module, nn.Conv3d):
        nn.init.kaiming_normal_(module.weight, a=1e-2)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)


def softmax_helper(x):
    # numerically stable softmax over the channel dimension
    rpt = [1 for _ in range(len(x.size()))]
    rpt[1] = x.size(1)
    x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
    e_x = torch.exp(x - x_max)
    return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)


class SetNetworkToVal(object):
    def __init__(self, use_dropout_sampling=False, norm_use_average=True):
        self.norm_use_average = norm_use_average
        self.use_dropout_sampling = use_dropout_sampling

    def __call__(self, module):
        if isinstance(module, (nn.Dropout3d, nn.Dropout2d, nn.Dropout)):
            module.train(self.use_dropout_sampling)
        elif isinstance(module, (nn.InstanceNorm3d, nn.InstanceNorm2d, nn.InstanceNorm1d,
                                 nn.BatchNorm3d, nn.BatchNorm2d, nn.BatchNorm1d)):
            module.train(not self.norm_use_average)


def postprocess_prediction(seg):
    # look for connected components, keep the largest one and delete everything else
    print("running postprocessing... ")
    mask = seg != 0
    lbls = label(mask, connectivity=mask.ndim)
    lbls_sizes = [np.sum(lbls == i) for i in np.unique(lbls)]
    largest_region = np.argmax(lbls_sizes[1:]) + 1
    seg[lbls != largest_region] = 0
    return seg


def subdirs(folder, join=True, prefix=None, suffix=None, sort=True):
    if join:
        l = os.path.join
    else:
        l = lambda x, y: y
    res = [l(folder, i) for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))
           and (prefix is None or i.startswith(prefix))
           and (suffix is None or i.endswith(suffix))]
    if sort:
        res.sort()
    return res


def subfiles(folder, join=True, prefix=None, suffix=None, sort=True):
    if join:
        l = os.path.join
    else:
        l = lambda x, y: y
    res = [l(folder, i) for i in os.listdir(folder) if os.path.isfile(os.path.join(folder, i))
           and (prefix is None or i.startswith(prefix))
           and (suffix is None or i.endswith(suffix))]
    if sort:
        res.sort()
    return res


subfolders = subdirs  # I am tired of confusing those


def maybe_mkdir_p(directory):
    # os.makedirs handles nested and absolute paths correctly; the original
    # split-and-mkdir loop dropped the leading "/" on absolute paths
    os.makedirs(directory, exist_ok=True)
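`postprocess_prediction` keeps only the largest connected component, which removes small spurious islands from a brain mask. A toy check of that behaviour (not part of the repository):

```python
import numpy as np
from HD_BET.utils import postprocess_prediction

seg = np.zeros((10, 10, 10), dtype=np.uint8)
seg[2:8, 2:8, 2:8] = 1   # large brain-like blob
seg[0, 0, 0] = 1         # spurious single-voxel island

cleaned = postprocess_prediction(seg.copy())
assert cleaned[0, 0, 0] == 0 and cleaned[5, 5, 5] == 1
```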
src/BrainIAC/IDHprediction/README.md
ADDED
@@ -0,0 +1,53 @@
# IDH Mutation Classification

<p align="left">
    <img src="idh.jpeg" width="200" alt="IDH Mutation Classification Example"/>
</p>

## Overview

We present the IDH mutation classification training and inference code for BrainIAC as a downstream task. The pipeline is trained and evaluated on T1CE and FLAIR scans, with AUC and F1 score as the evaluation metrics.

## Data Requirements

- **Input**: T1CE and FLAIR MR sequences from a single scan
- **Format**: NIfTI (.nii.gz)
- **Preprocessing**: Bias field corrected, registered to standard space, skull stripped
- **CSV Structure**:
  ```
  pat_id,scandate,label
  subject001,scan_sequence,1   # 1 for IDH mutant, 0 for wildtype
  ```

Refer to [quickstart.ipynb](../quickstart.ipynb) for how to preprocess the data and generate the CSV file.

## Setup

1. **Configuration**:
   Update the [config.yml](../config.yml) file accordingly.
   ```yaml
   # config.yml
   data:
     train_csv: "path/to/train.csv"
     val_csv: "path/to/val.csv"
     test_csv: "path/to/test.csv"
     root_dir: "../data/sample/processed"
     collate: 2                                # two-sequence pipeline

   checkpoints: "./checkpoints/idh_model.00"   # for inference/testing

   train:
     finetune: 'yes'                           # yes to finetune the entire model
     freeze: 'no'                              # yes to freeze the resnet backbone
     weights: ./checkpoints/brainiac.ckpt      # path to brainiac weights
   ```

2. **Training**:
   ```bash
   python -m IDHprediction.train_idh
   ```

3. **Inference**:
   ```bash
   python -m IDHprediction.infer_idh
   ```
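For completeness, the label CSV described above can be generated with a few lines of pandas. A hypothetical sketch; the subject IDs, scan identifiers, and labels are made up:

```python
import pandas as pd

df = pd.DataFrame({
    "pat_id":   ["subject001", "subject002"],
    "scandate": ["scan_sequence", "scan_sequence"],
    "label":    [1, 0],  # 1 = IDH mutant, 0 = wildtype
})
df.to_csv("train.csv", index=False)
```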
src/BrainIAC/IDHprediction/__init__.py
ADDED
File without changes
src/BrainIAC/IDHprediction/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (150 Bytes).
src/BrainIAC/IDHprediction/__pycache__/infer_idh.cpython-39.pyc
ADDED
Binary file (4.82 kB).
src/BrainIAC/IDHprediction/idh.jpeg
ADDED