Divyanshu Tak commited on
Commit
f5288df
·
1 Parent(s): d85b08e

Initial commit of BrainIAC Docker application

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .DS_Store +0 -0
  2. .gitattributes +9 -0
  3. Dockerfile +49 -0
  4. README.md +15 -7
  5. requirements.txt +23 -0
  6. src/.DS_Store +0 -0
  7. src/BrainIAC/.DS_Store +0 -0
  8. src/BrainIAC/Brainage/README.md +55 -0
  9. src/BrainIAC/Brainage/__init__.py +0 -0
  10. src/BrainIAC/Brainage/__pycache__/__init__.cpython-39.pyc +0 -0
  11. src/BrainIAC/Brainage/__pycache__/infer_brainage.cpython-39.pyc +0 -0
  12. src/BrainIAC/Brainage/brainage.jpeg +0 -0
  13. src/BrainIAC/Brainage/infer_brainage.py +85 -0
  14. src/BrainIAC/Brainage/train_brainage.py +230 -0
  15. src/BrainIAC/HD_BET/__pycache__/config.cpython-310.pyc +0 -0
  16. src/BrainIAC/HD_BET/__pycache__/config.cpython-38.pyc +0 -0
  17. src/BrainIAC/HD_BET/__pycache__/config.cpython-39.pyc +0 -0
  18. src/BrainIAC/HD_BET/__pycache__/data_loading.cpython-310.pyc +0 -0
  19. src/BrainIAC/HD_BET/__pycache__/data_loading.cpython-38.pyc +0 -0
  20. src/BrainIAC/HD_BET/__pycache__/data_loading.cpython-39.pyc +0 -0
  21. src/BrainIAC/HD_BET/__pycache__/hd_bet.cpython-310.pyc +0 -0
  22. src/BrainIAC/HD_BET/__pycache__/hd_bet.cpython-38.pyc +0 -0
  23. src/BrainIAC/HD_BET/__pycache__/network_architecture.cpython-310.pyc +0 -0
  24. src/BrainIAC/HD_BET/__pycache__/network_architecture.cpython-38.pyc +0 -0
  25. src/BrainIAC/HD_BET/__pycache__/network_architecture.cpython-39.pyc +0 -0
  26. src/BrainIAC/HD_BET/__pycache__/paths.cpython-310.pyc +0 -0
  27. src/BrainIAC/HD_BET/__pycache__/paths.cpython-38.pyc +0 -0
  28. src/BrainIAC/HD_BET/__pycache__/paths.cpython-39.pyc +0 -0
  29. src/BrainIAC/HD_BET/__pycache__/predict_case.cpython-310.pyc +0 -0
  30. src/BrainIAC/HD_BET/__pycache__/predict_case.cpython-38.pyc +0 -0
  31. src/BrainIAC/HD_BET/__pycache__/predict_case.cpython-39.pyc +0 -0
  32. src/BrainIAC/HD_BET/__pycache__/run.cpython-310.pyc +0 -0
  33. src/BrainIAC/HD_BET/__pycache__/run.cpython-38.pyc +0 -0
  34. src/BrainIAC/HD_BET/__pycache__/run.cpython-39.pyc +0 -0
  35. src/BrainIAC/HD_BET/__pycache__/utils.cpython-310.pyc +0 -0
  36. src/BrainIAC/HD_BET/__pycache__/utils.cpython-38.pyc +0 -0
  37. src/BrainIAC/HD_BET/__pycache__/utils.cpython-39.pyc +0 -0
  38. src/BrainIAC/HD_BET/config.py +121 -0
  39. src/BrainIAC/HD_BET/data_loading.py +121 -0
  40. src/BrainIAC/HD_BET/hd_bet.py +119 -0
  41. src/BrainIAC/HD_BET/network_architecture.py +213 -0
  42. src/BrainIAC/HD_BET/paths.py +6 -0
  43. src/BrainIAC/HD_BET/predict_case.py +126 -0
  44. src/BrainIAC/HD_BET/run.py +117 -0
  45. src/BrainIAC/HD_BET/utils.py +115 -0
  46. src/BrainIAC/IDHprediction/README.md +53 -0
  47. src/BrainIAC/IDHprediction/__init__.py +0 -0
  48. src/BrainIAC/IDHprediction/__pycache__/__init__.cpython-39.pyc +0 -0
  49. src/BrainIAC/IDHprediction/__pycache__/infer_idh.cpython-39.pyc +0 -0
  50. src/BrainIAC/IDHprediction/idh.jpeg +0 -0
.DS_Store ADDED
Binary file (6.15 kB). View file
 
.gitattributes CHANGED
@@ -33,3 +33,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ src/BrainIAC/checkpoints/*.pt filter=lfs diff=lfs merge=lfs -text
37
+ src/BrainIAC/hdbet_model/**/*.pth filter=lfs diff=lfs merge=lfs -text
38
+ src/BrainIAC/golden_image/**/*.nii.gz filter=lfs diff=lfs merge=lfs -text
39
+ src/BrainIAC/golden_image/**/*.nii filter=lfs diff=lfs merge=lfs -text
40
+ src/BrainIAC/preprocessing/atlases/*.nii filter=lfs diff=lfs merge=lfs -text
41
+ src/BrainIAC/golden_image/mni_templates/*.nii filter=lfs diff=lfs merge=lfs -text
42
+ src/BrainIAC/hdbet_model/**/*.onnx filter=lfs diff=lfs merge=lfs -text
43
+ src/BrainIAC/golden_image/*.nii filter=lfs diff=lfs merge=lfs -text
44
+ /Users/divyanshutak/spaces/BrainIAC-Brainage-V0/src/BrainIAC/static/images/*.jpeg filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use an official Python runtime as a parent image
2
+ FROM python:3.10-slim
3
+
4
+ # Set the working directory in the container
5
+ WORKDIR /app
6
+
7
+ # Install necessary system dependencies
8
+ RUN apt-get update && \
9
+ apt-get install -y --no-install-recommends \
10
+ unzip \
11
+ git \
12
+ # Add potential ITK/build dependencies (might need adjustment)
13
+ build-essential \
14
+ cmake \
15
+ libgl1 \
16
+ libglib2.0-0 \
17
+ # libitk5-dev # Removed: Rely on pip install of itk-elastix
18
+ && \
19
+ rm -rf /var/lib/apt/lists/*
20
+
21
+ # Copy the requirements file first to leverage Docker cache
22
+ COPY requirements.txt .
23
+
24
+ # Install Python packages specified in requirements.txt
25
+ # Using --no-cache-dir can reduce image size
26
+ RUN pip install --no-cache-dir -r requirements.txt
27
+
28
+ # --- HD-BET is now copied locally via src/BrainIAC ---
29
+
30
+ # Copy the rest of the application code (including local HD_BET)
31
+ COPY src/BrainIAC /app/BrainIAC
32
+
33
+ # Copy static files (like images)
34
+ COPY src/BrainIAC/static /app/BrainIAC/static
35
+
36
+ # Copy the MNI templates and parameter files
37
+ COPY src/BrainIAC/golden_image /app/BrainIAC/golden_image
38
+
39
+ # Copy the HD-BET models
40
+ COPY src/BrainIAC/hdbet_model /app/BrainIAC/hdbet_model
41
+
42
+ # Copy the model checkpoint
43
+ COPY src/BrainIAC/checkpoints/brainage_model_latest.pt /app/BrainIAC/checkpoints/brainage_model_latest.pt
44
+
45
+ # Make port 5000 available
46
+ EXPOSE 5000
47
+
48
+ # Run app.py when the container launches using gunicorn
49
+ CMD ["gunicorn", "--chdir", "/app/BrainIAC", "--bind", "0.0.0.0:5000", "--timeout", "600", "app:app"]
README.md CHANGED
@@ -1,12 +1,20 @@
1
  ---
2
- title: BrainIAC Brainage V0
3
- emoji: 👁
4
- colorFrom: pink
5
  colorTo: green
6
  sdk: docker
7
- pinned: false
8
- license: cc-by-nc-sa-2.0
9
- short_description: 'Brainage predictor '
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: BrainIAC Brain Age Prediction
3
+ emoji: 🧠
4
+ colorFrom: blue
5
  colorTo: green
6
  sdk: docker
7
+ app_port: 5000 # Make sure this matches the EXPOSE port in Dockerfile and gunicorn bind port
 
 
8
  ---
9
 
10
+ # BrainIAC: Brain Age Prediction Demo
11
+
12
+ This Hugging Face Space hosts an interactive demo for the BrainIAC model, predicting brain age from MRI scans.
13
+
14
+ **Features:**
15
+ - Upload NIfTI (.nii.gz) or DICOM (.zip) files.
16
+ - Optional preprocessing pipeline (Registration, Enhancement, Skull Stripping).
17
+ - Optional generation of saliency map visualizations.
18
+ - Predicts brain age in years based on the input scan.
19
+
20
+ The application runs inside a Docker container defined by the `Dockerfile`. It uses Flask, MONAI, HD-BET, and ITK/Elastix.
requirements.txt ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pytorch-lightning==2.3.3
2
+ monai==1.3.2
3
+ nibabel==5.2.1
4
+ scikit-image==0.21.0
5
+ scikit-learn==1.2.2
6
+ scipy==1.10.1
7
+ seaborn==0.12.2
8
+ numpy==1.23.5
9
+ autograd==1.7.0
10
+ matplotlib==3.7.1
11
+ SimpleITK==2.4.0
12
+ tqdm
13
+ pydicom
14
+ wandb
15
+ lifelines
16
+ torch==2.6.0
17
+ opencv-python
18
+ pandas
19
+ Flask
20
+ gunicorn
21
+ PyYAML
22
+ dicom2nifti
23
+ itk-elastix
src/.DS_Store ADDED
Binary file (8.2 kB). View file
 
src/BrainIAC/.DS_Store ADDED
Binary file (12.3 kB). View file
 
src/BrainIAC/Brainage/README.md ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Brain Age Prediction
2
+
3
+ <p align="left">
4
+ <img src="brainage.jpeg" width="200" alt="Brain Age Prediction Example"/>
5
+ </p>
6
+
7
+ ## Overview
8
+
9
+ We present the brainage prediction training and inference code for BrainIAC as a downstream task. The pipeline is trained and infered on T1 scans, with MAE as evaluation metric.
10
+
11
+ ## Data Requirements
12
+
13
+ - **Input**: T1-weighted MRI scans
14
+ - **Format**: NIFTI (.nii.gz)
15
+ - **Preprocessing**: Bias field corrected, registered to standard space, skull stripped
16
+ - **CSV Structure**:
17
+ ```
18
+ pat_id,scandate,label
19
+ subject001,20240101,65 # brain age in years
20
+ ```
21
+ refer to [ quickstart.ipynb](../quickstart.ipynb) to find how to preprocess data and generate csv file.
22
+
23
+
24
+ ## Setup
25
+
26
+ 1. **Configuration**:
27
+ change the [config.yml](../config.yml) file accordingly.
28
+ ```yaml
29
+ # config.yml
30
+ data:
31
+ train_csv: "path/to/train.csv"
32
+ val_csv: "path/to/val.csv"
33
+ test_csv: "path/to/test.csv"
34
+ root_dir: "../data/sample/processed"
35
+ collate: 1 # single scan framework
36
+
37
+ checkpoints: "./checkpoints/brainage_model.00" # for inference/testing
38
+
39
+ train:
40
+ finetune: 'yes' # yes to finetune the entire model
41
+ freeze: 'no' # yes to freeze the resnet backbone
42
+ weights: ./checkpoints/brainiac.ckpt # path to brainiac weights
43
+
44
+ ```
45
+
46
+ 2. **Training**:
47
+ ```bash
48
+ python -m Brainage.train_brainage
49
+ ```
50
+
51
+ 3. **Inference**:
52
+ ```bash
53
+ python -m Brainage.infer_brainage
54
+ ```
55
+
src/BrainIAC/Brainage/__init__.py ADDED
File without changes
src/BrainIAC/Brainage/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (145 Bytes). View file
 
src/BrainIAC/Brainage/__pycache__/infer_brainage.cpython-39.pyc ADDED
Binary file (3 kB). View file
 
src/BrainIAC/Brainage/brainage.jpeg ADDED
src/BrainIAC/Brainage/infer_brainage.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pandas as pd
3
+ import os
4
+ from tqdm import tqdm
5
+ from torch.utils.data import DataLoader
6
+ from torch.cuda.amp import autocast
7
+ from sklearn.metrics import mean_absolute_error
8
+ import sys
9
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
10
+ from dataset2 import MedicalImageDatasetBalancedIntensity3D
11
+ from model import Backbone, SingleScanModel, Classifier
12
+ from utils import BaseConfig
13
+
14
+ class BrainAgeInference(BaseConfig):
15
+ """
16
+ Inference class for brain age prediction model.
17
+ """
18
+
19
+ def __init__(self):
20
+ """Initialize the inference setup with model and data."""
21
+ super().__init__()
22
+ self.setup_model()
23
+ self.setup_data()
24
+
25
+ def setup_model(self):
26
+ config = self.get_config()
27
+ self.backbone = Backbone()
28
+ self.classifier = Classifier(d_model=2048)
29
+ self.model = SingleScanModel(self.backbone, self.classifier)
30
+
31
+ # Load weights
32
+ checkpoint = torch.load(config["infer"]["checkpoints"], map_location=self.device)
33
+ self.model.load_state_dict(checkpoint["model_state_dict"])
34
+ self.model = self.model.to(self.device)
35
+ self.model.eval()
36
+ print("Model and checkpoint loaded!")
37
+
38
+ ## spinup dataloaders
39
+ def setup_data(self):
40
+ config = self.get_config()
41
+ self.test_dataset = MedicalImageDatasetBalancedIntensity3D(
42
+ csv_path=config["data"]["test_csv"],
43
+ root_dir=config["data"]["root_dir"]
44
+ )
45
+ self.test_loader = DataLoader(
46
+ self.test_dataset,
47
+ batch_size=1,
48
+ shuffle=False,
49
+ collate_fn=self.custom_collate,
50
+ num_workers=1
51
+ )
52
+
53
+ def infer(self):
54
+ """ Infer pass """
55
+ results_df = pd.DataFrame(columns=['PredictedAge', 'TrueAge'])
56
+ all_labels = []
57
+ all_predictions = []
58
+
59
+ with torch.no_grad():
60
+ for sample in tqdm(self.test_loader, desc="Inference", unit="batch"):
61
+ inputs = sample['image'].to(self.device)
62
+ labels = sample['label'].float().to(self.device)
63
+
64
+ with autocast():
65
+ outputs = self.model(inputs)
66
+
67
+ predictions = outputs.cpu().numpy().flatten()
68
+ all_labels.extend(labels.cpu().numpy().flatten())
69
+ all_predictions.extend(predictions)
70
+
71
+ result = pd.DataFrame({
72
+ 'PredictedAge': predictions,
73
+ 'TrueAge': labels.cpu().numpy().flatten()
74
+ })
75
+ results_df = pd.concat([results_df, result], ignore_index=True)
76
+
77
+ mae = mean_absolute_error(all_labels, all_predictions)
78
+ print(f"Mean Absolute Error (MAE): {mae:.4f} months")
79
+ results_df.to_csv('./data/output/brainage_output.csv', index=False)
80
+
81
+ return mae
82
+
83
+ if __name__ == "__main__":
84
+ inferencer = BrainAgeInference()
85
+ mae = inferencer.infer()
src/BrainIAC/Brainage/train_brainage.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torch.utils.data import DataLoader
5
+ import wandb
6
+ from tqdm import tqdm
7
+ from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
8
+ from torch.cuda.amp import GradScaler, autocast
9
+ from sklearn.metrics import mean_absolute_error
10
+ import os
11
+ import sys
12
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
13
+ from dataset2 import MedicalImageDatasetBalancedIntensity3D, TransformationMedicalImageDatasetBalancedIntensity3D
14
+ from model import Backbone, SingleScanModel, Classifier
15
+ from utils import BaseConfig
16
+
17
+
18
+ class BrainAgeTrainer(BaseConfig):
19
+ """
20
+ A trainer class for brain age prediction models.
21
+
22
+ This class handles the complete training pipeline including model setup,
23
+ data loading, training loop, and validation.
24
+ Inherits from BaseConfig for configuration management.
25
+ """
26
+
27
+ def __init__(self):
28
+ """Initialize the trainer with model, data, and training setup."""
29
+ super().__init__()
30
+ self.setup_wandb()
31
+ self.setup_model()
32
+ self.setup_data()
33
+ self.setup_training()
34
+
35
+ ## setup wandb logger
36
+ def setup_wandb(self):
37
+ config = self.get_config()
38
+ wandb.init(
39
+ project=config['logger']['project_name'],
40
+ name=config['logger']['run_name'],
41
+ config=config
42
+ )
43
+
44
+ def setup_model(self):
45
+ """
46
+ Set up the model architecture.
47
+
48
+ Initializes the backbone and classifier blocks, and loads
49
+ checkpoints
50
+ """
51
+ self.backbone = Backbone()
52
+ self.classifier = Classifier(d_model=2048)
53
+ self.model = SingleScanModel(self.backbone, self.classifier)
54
+
55
+ # Load BrainIACs weights
56
+ config = self.get_config()
57
+ if config["train"]["finetune"] == "yes":
58
+ checkpoint = torch.load(config["train"]["weights"], map_location=self.device)
59
+ state_dict = checkpoint["state_dict"]
60
+ filtered_state_dict = {}
61
+ for key, value in state_dict.items():
62
+ new_key = key.replace("module.", "backbone.") if key.startswith("module.") else key
63
+ filtered_state_dict[new_key] = value
64
+ self.model.backbone.load_state_dict(filtered_state_dict, strict=False)
65
+ print("Pretrained weights loaded!")
66
+
67
+ # Freeze backbone if specified
68
+ if config["train"]["freeze"] == "yes":
69
+ for param in self.model.backbone.parameters():
70
+ param.requires_grad = False
71
+ print("Backbone weights frozen!")
72
+
73
+ self.model = self.model.to(self.device)
74
+
75
+ def setup_data(self):
76
+ """
77
+ Set up data loaders for training and validation.
78
+ Inherit configuration from the base config
79
+ """
80
+ config = self.get_config()
81
+ self.train_dataset = TransformationMedicalImageDatasetBalancedIntensity3D(
82
+ csv_path=config['data']['train_csv'],
83
+ root_dir=config["data"]["root_dir"]
84
+ )
85
+ self.val_dataset = MedicalImageDatasetBalancedIntensity3D(
86
+ csv_path=config['data']['val_csv'],
87
+ root_dir=config["data"]["root_dir"]
88
+ )
89
+
90
+ self.train_loader = DataLoader(
91
+ self.train_dataset,
92
+ batch_size=config["data"]["batch_size"],
93
+ shuffle=True,
94
+ collate_fn=self.custom_collate,
95
+ num_workers=config["data"]["num_workers"]
96
+ )
97
+ self.val_loader = DataLoader(
98
+ self.val_dataset,
99
+ batch_size=1,
100
+ shuffle=False,
101
+ collate_fn=self.custom_collate,
102
+ num_workers=1
103
+ )
104
+
105
+ def setup_training(self):
106
+ """
107
+ Set up training config with loss, scheduler, optimizer.
108
+ """
109
+ config = self.get_config()
110
+ self.criterion = nn.MSELoss()
111
+ self.optimizer = optim.Adam(
112
+ self.model.parameters(),
113
+ lr=config['optim']['lr'],
114
+ weight_decay=config["optim"]["weight_decay"]
115
+ )
116
+ self.scheduler = CosineAnnealingWarmRestarts(self.optimizer, T_0=50, T_mult=2)
117
+ self.scaler = GradScaler()
118
+
119
+ def train(self):
120
+ """
121
+ main training loop
122
+ """
123
+ config = self.get_config()
124
+ max_epochs = config['optim']['max_epochs']
125
+ best_val_loss = float('inf')
126
+ best_val_mae = float('inf')
127
+
128
+ for epoch in range(max_epochs):
129
+ train_loss = self.train_epoch(epoch, max_epochs)
130
+ val_loss, mae = self.validate_epoch(epoch, max_epochs)
131
+
132
+ # Save best model
133
+ if (val_loss <= best_val_loss) and (mae <= best_val_mae):
134
+ print(f"Improved Val Loss from {best_val_loss:.4f} to {val_loss:.4f}")
135
+ print(f"Improved Val MAE from {best_val_mae:.4f} to {mae:.4f}")
136
+ best_val_loss = val_loss
137
+ best_val_mae = mae
138
+ self.save_checkpoint(epoch, val_loss, mae)
139
+
140
+ wandb.finish()
141
+
142
+ def train_epoch(self, epoch, max_epochs):
143
+ """
144
+ Train pass.
145
+
146
+ Args:
147
+ epoch (int): Current epoch number
148
+ max_epochs (int): Total number of epochs
149
+
150
+ Returns:
151
+ float: Average training loss for the epoch
152
+ """
153
+ self.model.train()
154
+ train_loss = 0.0
155
+
156
+ for sample in tqdm(self.train_loader, desc=f"Training Epoch {epoch}/{max_epochs-1}"):
157
+ inputs = sample['image'].to(self.device)
158
+ labels = sample['label'].float().to(self.device)
159
+
160
+ self.optimizer.zero_grad()
161
+ with autocast():
162
+ outputs = self.model(inputs)
163
+ loss = self.criterion(outputs, labels.unsqueeze(1))
164
+
165
+ self.scaler.scale(loss).backward()
166
+ self.scaler.step(self.optimizer)
167
+ self.scaler.update()
168
+
169
+ train_loss += loss.item() * inputs.size(0)
170
+
171
+ train_loss = train_loss / len(self.train_loader.dataset)
172
+ wandb.log({"Train Loss": train_loss})
173
+ return train_loss
174
+
175
+ def validate_epoch(self, epoch, max_epochs):
176
+ """
177
+ Validation pass.
178
+
179
+ Args:
180
+ epoch (int): Current epoch number
181
+ max_epochs (int): Total number of epochs
182
+
183
+ Returns:
184
+ tuple: (validation_loss, mean_absolute_error)
185
+ """
186
+ self.model.eval()
187
+ val_loss = 0.0
188
+ all_labels = []
189
+ all_preds = []
190
+
191
+ with torch.no_grad():
192
+ for sample in tqdm(self.val_loader, desc=f"Validation Epoch {epoch}/{max_epochs-1}"):
193
+ inputs = sample['image'].to(self.device)
194
+ labels = sample['label'].float().to(self.device)
195
+
196
+ outputs = self.model(inputs)
197
+ loss = self.criterion(outputs, labels.unsqueeze(1))
198
+
199
+ val_loss += loss.item() * inputs.size(0)
200
+ all_labels.extend(labels.cpu().numpy().flatten())
201
+ all_preds.extend(outputs.cpu().numpy().flatten())
202
+
203
+ val_loss = val_loss / len(self.val_loader.dataset)
204
+ mae = mean_absolute_error(all_labels, all_preds)
205
+
206
+ wandb.log({"Val Loss": val_loss, "MAE": mae})
207
+ self.scheduler.step(val_loss)
208
+
209
+ print(f"Epoch {epoch}/{max_epochs-1} Val Loss: {val_loss:.4f} MAE: {mae:.4f}")
210
+ return val_loss, mae
211
+
212
+ def save_checkpoint(self, epoch, loss, mae):
213
+ """
214
+ Save model checkpoint.
215
+ """
216
+ config = self.get_config()
217
+ checkpoint = {
218
+ 'model_state_dict': self.model.state_dict(),
219
+ 'loss': loss,
220
+ 'epoch': epoch,
221
+ }
222
+ save_path = os.path.join(
223
+ config['logger']['save_dir'],
224
+ config['logger']['save_name'].format(epoch=epoch, loss=loss, metric=mae)
225
+ )
226
+ torch.save(checkpoint, save_path)
227
+
228
+ if __name__ == "__main__":
229
+ trainer = BrainAgeTrainer()
230
+ trainer.train()
src/BrainIAC/HD_BET/__pycache__/config.cpython-310.pyc ADDED
Binary file (4.15 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/config.cpython-38.pyc ADDED
Binary file (4.13 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/config.cpython-39.pyc ADDED
Binary file (4.19 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/data_loading.cpython-310.pyc ADDED
Binary file (4.47 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/data_loading.cpython-38.pyc ADDED
Binary file (4.48 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/data_loading.cpython-39.pyc ADDED
Binary file (4.46 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/hd_bet.cpython-310.pyc ADDED
Binary file (4.21 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/hd_bet.cpython-38.pyc ADDED
Binary file (4.27 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/network_architecture.cpython-310.pyc ADDED
Binary file (6.78 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/network_architecture.cpython-38.pyc ADDED
Binary file (6.89 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/network_architecture.cpython-39.pyc ADDED
Binary file (6.84 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/paths.cpython-310.pyc ADDED
Binary file (324 Bytes). View file
 
src/BrainIAC/HD_BET/__pycache__/paths.cpython-38.pyc ADDED
Binary file (335 Bytes). View file
 
src/BrainIAC/HD_BET/__pycache__/paths.cpython-39.pyc ADDED
Binary file (322 Bytes). View file
 
src/BrainIAC/HD_BET/__pycache__/predict_case.cpython-310.pyc ADDED
Binary file (3.68 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/predict_case.cpython-38.pyc ADDED
Binary file (3.67 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/predict_case.cpython-39.pyc ADDED
Binary file (3.68 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/run.cpython-310.pyc ADDED
Binary file (3.83 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/run.cpython-38.pyc ADDED
Binary file (3.88 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/run.cpython-39.pyc ADDED
Binary file (3.85 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/utils.cpython-310.pyc ADDED
Binary file (4.68 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/utils.cpython-38.pyc ADDED
Binary file (4.85 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/utils.cpython-39.pyc ADDED
Binary file (4.81 kB). View file
 
src/BrainIAC/HD_BET/config.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from HD_BET.utils import SetNetworkToVal, softmax_helper
4
+ from abc import abstractmethod
5
+ from HD_BET.network_architecture import Network
6
+
7
+
8
+ class BaseConfig(object):
9
+ def __init__(self):
10
+ pass
11
+
12
+ @abstractmethod
13
+ def get_split(self, fold, random_state=12345):
14
+ pass
15
+
16
+ @abstractmethod
17
+ def get_network(self, mode="train"):
18
+ pass
19
+
20
+ @abstractmethod
21
+ def get_basic_generators(self, fold):
22
+ pass
23
+
24
+ @abstractmethod
25
+ def get_data_generators(self, fold):
26
+ pass
27
+
28
+ def preprocess(self, data):
29
+ return data
30
+
31
+ def __repr__(self):
32
+ res = ""
33
+ for v in vars(self):
34
+ if not v.startswith("__") and not v.startswith("_") and v != 'dataset':
35
+ res += (v + ": " + str(self.__getattribute__(v)) + "\n")
36
+ return res
37
+
38
+
39
+ class HD_BET_Config(BaseConfig):
40
+ def __init__(self):
41
+ super(HD_BET_Config, self).__init__()
42
+
43
+ self.EXPERIMENT_NAME = self.__class__.__name__ # just a generic experiment name
44
+
45
+ # network parameters
46
+ self.net_base_num_layers = 21
47
+ self.BATCH_SIZE = 2
48
+ self.net_do_DS = True
49
+ self.net_dropout_p = 0.0
50
+ self.net_use_inst_norm = True
51
+ self.net_conv_use_bias = True
52
+ self.net_norm_use_affine = True
53
+ self.net_leaky_relu_slope = 1e-1
54
+
55
+ # hyperparameters
56
+ self.INPUT_PATCH_SIZE = (128, 128, 128)
57
+ self.num_classes = 2
58
+ self.selected_data_channels = range(1)
59
+
60
+ # data augmentation
61
+ self.da_mirror_axes = (2, 3, 4)
62
+
63
+ # validation
64
+ self.val_use_DO = False
65
+ self.val_use_train_mode = False # for dropout sampling
66
+ self.val_num_repeats = 1 # only useful if dropout sampling
67
+ self.val_batch_size = 1 # only useful if dropout sampling
68
+ self.val_save_npz = True
69
+ self.val_do_mirroring = True # test time data augmentation via mirroring
70
+ self.val_write_images = True
71
+ self.net_input_must_be_divisible_by = 16 # we could make a network class that has this as a property
72
+ self.val_min_size = self.INPUT_PATCH_SIZE
73
+ self.val_fn = None
74
+
75
+ # CAREFUL! THIS IS A HACK TO MAKE PYTORCH 0.3 STATE DICTS COMPATIBLE WITH PYTORCH 0.4 (setting keep_runnings_
76
+ # stats=True but not using them in validation. keep_runnings_stats was True before 0.3 but unused and defaults
77
+ # to false in 0.4)
78
+ self.val_use_moving_averages = False
79
+
80
+ def get_network(self, train=True, pretrained_weights=None):
81
+ net = Network(self.num_classes, len(self.selected_data_channels), self.net_base_num_layers,
82
+ self.net_dropout_p, softmax_helper, self.net_leaky_relu_slope, self.net_conv_use_bias,
83
+ self.net_norm_use_affine, True, self.net_do_DS)
84
+
85
+ if pretrained_weights is not None:
86
+ net.load_state_dict(
87
+ torch.load(pretrained_weights, map_location=lambda storage, loc: storage))
88
+
89
+ if train:
90
+ net.train(True)
91
+ else:
92
+ net.train(False)
93
+ net.apply(SetNetworkToVal(self.val_use_DO, self.val_use_moving_averages))
94
+ net.do_ds = False
95
+
96
+ optimizer = None
97
+ self.lr_scheduler = None
98
+ return net, optimizer
99
+
100
+ def get_data_generators(self, fold):
101
+ pass
102
+
103
+ def get_split(self, fold, random_state=12345):
104
+ pass
105
+
106
+ def get_basic_generators(self, fold):
107
+ pass
108
+
109
+ def on_epoch_end(self, epoch):
110
+ pass
111
+
112
+ def preprocess(self, data):
113
+ data = np.copy(data)
114
+ for c in range(data.shape[0]):
115
+ data[c] -= data[c].mean()
116
+ data[c] /= data[c].std()
117
+ return data
118
+
119
+
120
+ config = HD_BET_Config
121
+
src/BrainIAC/HD_BET/data_loading.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import SimpleITK as sitk
2
+ import numpy as np
3
+ from skimage.transform import resize
4
+
5
+
6
+ def resize_image(image, old_spacing, new_spacing, order=3):
7
+ new_shape = (int(np.round(old_spacing[0]/new_spacing[0]*float(image.shape[0]))),
8
+ int(np.round(old_spacing[1]/new_spacing[1]*float(image.shape[1]))),
9
+ int(np.round(old_spacing[2]/new_spacing[2]*float(image.shape[2]))))
10
+ return resize(image, new_shape, order=order, mode='edge', cval=0, anti_aliasing=False)
11
+
12
+
13
+ def preprocess_image(itk_image, is_seg=False, spacing_target=(1, 0.5, 0.5)):
14
+ spacing = np.array(itk_image.GetSpacing())[[2, 1, 0]]
15
+ image = sitk.GetArrayFromImage(itk_image).astype(float)
16
+
17
+ assert len(image.shape) == 3, "The image has unsupported number of dimensions. Only 3D images are allowed"
18
+
19
+ if not is_seg:
20
+ if np.any([[i != j] for i, j in zip(spacing, spacing_target)]):
21
+ image = resize_image(image, spacing, spacing_target).astype(np.float32)
22
+
23
+ image -= image.mean()
24
+ image /= image.std()
25
+ else:
26
+ new_shape = (int(np.round(spacing[0] / spacing_target[0] * float(image.shape[0]))),
27
+ int(np.round(spacing[1] / spacing_target[1] * float(image.shape[1]))),
28
+ int(np.round(spacing[2] / spacing_target[2] * float(image.shape[2]))))
29
+ image = resize_segmentation(image, new_shape, 1)
30
+ return image
31
+
32
+
33
+ def load_and_preprocess(mri_file):
34
+ images = {}
35
+ # t1
36
+ images["T1"] = sitk.ReadImage(mri_file)
37
+
38
+ properties_dict = {
39
+ "spacing": images["T1"].GetSpacing(),
40
+ "direction": images["T1"].GetDirection(),
41
+ "size": images["T1"].GetSize(),
42
+ "origin": images["T1"].GetOrigin()
43
+ }
44
+
45
+ for k in images.keys():
46
+ images[k] = preprocess_image(images[k], is_seg=False, spacing_target=(1.5, 1.5, 1.5))
47
+
48
+ properties_dict['size_before_cropping'] = images["T1"].shape
49
+
50
+ imgs = []
51
+ for seq in ['T1']:
52
+ imgs.append(images[seq][None])
53
+ all_data = np.vstack(imgs)
54
+ print("image shape after preprocessing: ", str(all_data[0].shape))
55
+ return all_data, properties_dict
56
+
57
+
58
+ def save_segmentation_nifti(segmentation, dct, out_fname, order=1):
59
+ '''
60
+ segmentation must have the same spacing as the original nifti (for now). segmentation may have been cropped out
61
+ of the original image
62
+
63
+ dct:
64
+ size_before_cropping
65
+ brain_bbox
66
+ size -> this is the original size of the dataset, if the image was not resampled, this is the same as size_before_cropping
67
+ spacing
68
+ origin
69
+ direction
70
+
71
+ :param segmentation:
72
+ :param dct:
73
+ :param out_fname:
74
+ :return:
75
+ '''
76
+ old_size = dct.get('size_before_cropping')
77
+ bbox = dct.get('brain_bbox')
78
+ if bbox is not None:
79
+ seg_old_size = np.zeros(old_size)
80
+ for c in range(3):
81
+ bbox[c][1] = np.min((bbox[c][0] + segmentation.shape[c], old_size[c]))
82
+ seg_old_size[bbox[0][0]:bbox[0][1],
83
+ bbox[1][0]:bbox[1][1],
84
+ bbox[2][0]:bbox[2][1]] = segmentation
85
+ else:
86
+ seg_old_size = segmentation
87
+ if np.any(np.array(seg_old_size) != np.array(dct['size'])[[2, 1, 0]]):
88
+ seg_old_spacing = resize_segmentation(seg_old_size, np.array(dct['size'])[[2, 1, 0]], order=order)
89
+ else:
90
+ seg_old_spacing = seg_old_size
91
+ seg_resized_itk = sitk.GetImageFromArray(seg_old_spacing.astype(np.int32))
92
+ seg_resized_itk.SetSpacing(np.array(dct['spacing'])[[0, 1, 2]])
93
+ seg_resized_itk.SetOrigin(dct['origin'])
94
+ seg_resized_itk.SetDirection(dct['direction'])
95
+ sitk.WriteImage(seg_resized_itk, out_fname)
96
+
97
+
98
+ def resize_segmentation(segmentation, new_shape, order=3, cval=0):
99
+ '''
100
+ Taken from batchgenerators (https://github.com/MIC-DKFZ/batchgenerators) to prevent dependency
101
+
102
+ Resizes a segmentation map. Supports all orders (see skimage documentation). Will transform segmentation map to one
103
+ hot encoding which is resized and transformed back to a segmentation map.
104
+ This prevents interpolation artifacts ([0, 0, 2] -> [0, 1, 2])
105
+ :param segmentation:
106
+ :param new_shape:
107
+ :param order:
108
+ :return:
109
+ '''
110
+ tpe = segmentation.dtype
111
+ unique_labels = np.unique(segmentation)
112
+ assert len(segmentation.shape) == len(new_shape), "new shape must have same dimensionality as segmentation"
113
+ if order == 0:
114
+ return resize(segmentation, new_shape, order, mode="constant", cval=cval, clip=True, anti_aliasing=False).astype(tpe)
115
+ else:
116
+ reshaped = np.zeros(new_shape, dtype=segmentation.dtype)
117
+
118
+ for i, c in enumerate(unique_labels):
119
+ reshaped_multihot = resize((segmentation == c).astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False)
120
+ reshaped[reshaped_multihot >= 0.5] = c
121
+ return reshaped
src/BrainIAC/HD_BET/hd_bet.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import os
4
+ import sys
5
+ sys.path.append("/mnt/93E8-0534/AIDAN/HDBET/")
6
+ from HD_BET.run import run_hd_bet
7
+ from HD_BET.utils import maybe_mkdir_p, subfiles
8
+ import HD_BET
9
+
10
+ def hd_bet(input_file_or_dir,output_file_or_dir,mode,device,tta,pp=1,save_mask=0,overwrite_existing=1):
11
+
12
+ if output_file_or_dir is None:
13
+ output_file_or_dir = os.path.join(os.path.dirname(input_file_or_dir),
14
+ os.path.basename(input_file_or_dir).split(".")[0] + "_bet")
15
+
16
+
17
+ params_file = os.path.join(HD_BET.__path__[0], "model_final.py")
18
+ config_file = os.path.join(HD_BET.__path__[0], "config.py")
19
+
20
+ assert os.path.abspath(input_file_or_dir) != os.path.abspath(output_file_or_dir), "output must be different from input"
21
+
22
+ if device == 'cpu':
23
+ pass
24
+ else:
25
+ device = int(device)
26
+
27
+ if os.path.isdir(input_file_or_dir):
28
+ maybe_mkdir_p(output_file_or_dir)
29
+ input_files = subfiles(input_file_or_dir, suffix='_0000.nii.gz', join=False)
30
+
31
+ if len(input_files) == 0:
32
+ raise RuntimeError("input is a folder but no nifti files (.nii.gz) were found in here")
33
+
34
+ output_files = [os.path.join(output_file_or_dir, i) for i in input_files]
35
+ input_files = [os.path.join(input_file_or_dir, i) for i in input_files]
36
+ else:
37
+ if not output_file_or_dir.endswith('.nii.gz'):
38
+ output_file_or_dir += '.nii.gz'
39
+ assert os.path.abspath(input_file_or_dir) != os.path.abspath(output_file_or_dir), "output must be different from input"
40
+
41
+ output_files = [output_file_or_dir]
42
+ input_files = [input_file_or_dir]
43
+
44
+ if tta == 0:
45
+ tta = False
46
+ elif tta == 1:
47
+ tta = True
48
+ else:
49
+ raise ValueError("Unknown value for tta: %s. Expected: 0 or 1" % str(tta))
50
+
51
+ if overwrite_existing == 0:
52
+ overwrite_existing = False
53
+ elif overwrite_existing == 1:
54
+ overwrite_existing = True
55
+ else:
56
+ raise ValueError("Unknown value for overwrite_existing: %s. Expected: 0 or 1" % str(overwrite_existing))
57
+
58
+ if pp == 0:
59
+ pp = False
60
+ elif pp == 1:
61
+ pp = True
62
+ else:
63
+ raise ValueError("Unknown value for pp: %s. Expected: 0 or 1" % str(pp))
64
+
65
+ if save_mask == 0:
66
+ save_mask = False
67
+ elif save_mask == 1:
68
+ save_mask = True
69
+ else:
70
+ raise ValueError("Unknown value for pp: %s. Expected: 0 or 1" % str(pp))
71
+
72
+ run_hd_bet(input_files, output_files, mode, config_file, device, pp, tta, save_mask, overwrite_existing)
73
+
74
+
75
+ if __name__ == "__main__":
76
+ print("\n########################")
77
+ print("If you are using hd-bet, please cite the following paper:")
78
+ print("Isensee F, Schell M, Tursunova I, Brugnara G, Bonekamp D, Neuberger U, Wick A, Schlemmer HP, Heiland S, Wick W,"
79
+ "Bendszus M, Maier-Hein KH, Kickingereder P. Automated brain extraction of multi-sequence MRI using artificial"
80
+ "neural networks. arXiv preprint arXiv:1901.11341, 2019.")
81
+ print("########################\n")
82
+
83
+ import argparse
84
+ parser = argparse.ArgumentParser()
85
+ parser.add_argument('-i', '--input', help='input. Can be either a single file name or an input folder. If file: must be '
86
+ 'nifti (.nii.gz) and can only be 3D. No support for 4d images, use fslsplit to '
87
+ 'split 4d sequences into 3d images. If folder: all files ending with .nii.gz '
88
+ 'within that folder will be brain extracted.', required=True, type=str)
89
+ parser.add_argument('-o', '--output', help='output. Can be either a filename or a folder. If it does not exist, the folder'
90
+ ' will be created', required=False, type=str)
91
+ parser.add_argument('-mode', type=str, default='accurate', help='can be either \'fast\' or \'accurate\'. Fast will '
92
+ 'use only one set of parameters whereas accurate will '
93
+ 'use the five sets of parameters that resulted from '
94
+ 'our cross-validation as an ensemble. Default: '
95
+ 'accurate',
96
+ required=False)
97
+ parser.add_argument('-device', default='0', type=str, help='used to set on which device the prediction will run. '
98
+ 'Must be either int or str. Use int for GPU id or '
99
+ '\'cpu\' to run on CPU. When using CPU you should '
100
+ 'consider disabling tta. Default for -device is: 0',
101
+ required=False)
102
+ parser.add_argument('-tta', default=1, required=False, type=int, help='whether to use test time data augmentation '
103
+ '(mirroring). 1= True, 0=False. Disable this '
104
+ 'if you are using CPU to speed things up! '
105
+ 'Default: 1')
106
+ parser.add_argument('-pp', default=1, type=int, required=False, help='set to 0 to disabe postprocessing (remove all'
107
+ ' but the largest connected component in '
108
+ 'the prediction. Default: 1')
109
+ parser.add_argument('-s', '--save_mask', default=1, type=int, required=False, help='if set to 0 the segmentation '
110
+ 'mask will not be '
111
+ 'saved')
112
+ parser.add_argument('--overwrite_existing', default=1, type=int, required=False, help="set this to 0 if you don't "
113
+ "want to overwrite existing "
114
+ "predictions")
115
+
116
+ args = parser.parse_args()
117
+
118
+ hd_bet(args.input,args.output,args.mode,args.device,args.tta,args.pp,args.save_mask,args.overwrite_existing)
119
+
src/BrainIAC/HD_BET/network_architecture.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from HD_BET.utils import softmax_helper
5
+
6
+
7
+ class EncodingModule(nn.Module):
8
+ def __init__(self, in_channels, out_channels, filter_size=3, dropout_p=0.3, leakiness=1e-2, conv_bias=True,
9
+ inst_norm_affine=True, lrelu_inplace=True):
10
+ nn.Module.__init__(self)
11
+ self.dropout_p = dropout_p
12
+ self.lrelu_inplace = lrelu_inplace
13
+ self.inst_norm_affine = inst_norm_affine
14
+ self.conv_bias = conv_bias
15
+ self.leakiness = leakiness
16
+ self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
17
+ self.conv1 = nn.Conv3d(in_channels, out_channels, filter_size, 1, (filter_size - 1) // 2, bias=self.conv_bias)
18
+ self.dropout = nn.Dropout3d(dropout_p)
19
+ self.bn_2 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
20
+ self.conv2 = nn.Conv3d(out_channels, out_channels, filter_size, 1, (filter_size - 1) // 2, bias=self.conv_bias)
21
+
22
+ def forward(self, x):
23
+ skip = x
24
+ x = F.leaky_relu(self.bn_1(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
25
+ x = self.conv1(x)
26
+ if self.dropout_p is not None and self.dropout_p > 0:
27
+ x = self.dropout(x)
28
+ x = F.leaky_relu(self.bn_2(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
29
+ x = self.conv2(x)
30
+ x = x + skip
31
+ return x
32
+
33
+
34
+ class Upsample(nn.Module):
35
+ def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=True):
36
+ super(Upsample, self).__init__()
37
+ self.align_corners = align_corners
38
+ self.mode = mode
39
+ self.scale_factor = scale_factor
40
+ self.size = size
41
+
42
+ def forward(self, x):
43
+ return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode,
44
+ align_corners=self.align_corners)
45
+
46
+
47
+ class LocalizationModule(nn.Module):
48
+ def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
49
+ lrelu_inplace=True):
50
+ nn.Module.__init__(self)
51
+ self.lrelu_inplace = lrelu_inplace
52
+ self.inst_norm_affine = inst_norm_affine
53
+ self.conv_bias = conv_bias
54
+ self.leakiness = leakiness
55
+ self.conv1 = nn.Conv3d(in_channels, in_channels, 3, 1, 1, bias=self.conv_bias)
56
+ self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
57
+ self.conv2 = nn.Conv3d(in_channels, out_channels, 1, 1, 0, bias=self.conv_bias)
58
+ self.bn_2 = nn.InstanceNorm3d(out_channels, affine=self.inst_norm_affine, track_running_stats=True)
59
+
60
+ def forward(self, x):
61
+ x = F.leaky_relu(self.bn_1(self.conv1(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
62
+ x = F.leaky_relu(self.bn_2(self.conv2(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
63
+ return x
64
+
65
+
66
+ class UpsamplingModule(nn.Module):
67
+ def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
68
+ lrelu_inplace=True):
69
+ nn.Module.__init__(self)
70
+ self.lrelu_inplace = lrelu_inplace
71
+ self.inst_norm_affine = inst_norm_affine
72
+ self.conv_bias = conv_bias
73
+ self.leakiness = leakiness
74
+ self.upsample = Upsample(scale_factor=2, mode="trilinear", align_corners=True)
75
+ self.upsample_conv = nn.Conv3d(in_channels, out_channels, 3, 1, 1, bias=self.conv_bias)
76
+ self.bn = nn.InstanceNorm3d(out_channels, affine=self.inst_norm_affine, track_running_stats=True)
77
+
78
+ def forward(self, x):
79
+ x = F.leaky_relu(self.bn(self.upsample_conv(self.upsample(x))), negative_slope=self.leakiness,
80
+ inplace=self.lrelu_inplace)
81
+ return x
82
+
83
+
84
+ class DownsamplingModule(nn.Module):
85
+ def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
86
+ lrelu_inplace=True):
87
+ nn.Module.__init__(self)
88
+ self.lrelu_inplace = lrelu_inplace
89
+ self.inst_norm_affine = inst_norm_affine
90
+ self.conv_bias = conv_bias
91
+ self.leakiness = leakiness
92
+ self.bn = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
93
+ self.downsample = nn.Conv3d(in_channels, out_channels, 3, 2, 1, bias=self.conv_bias)
94
+
95
+ def forward(self, x):
96
+ x = F.leaky_relu(self.bn(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
97
+ b = self.downsample(x)
98
+ return x, b
99
+
100
+
101
+ class Network(nn.Module):
102
+ def __init__(self, num_classes=4, num_input_channels=4, base_filters=16, dropout_p=0.3,
103
+ final_nonlin=softmax_helper, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
104
+ lrelu_inplace=True, do_ds=True):
105
+ super(Network, self).__init__()
106
+
107
+ self.do_ds = do_ds
108
+ self.lrelu_inplace = lrelu_inplace
109
+ self.inst_norm_affine = inst_norm_affine
110
+ self.conv_bias = conv_bias
111
+ self.leakiness = leakiness
112
+ self.final_nonlin = final_nonlin
113
+ self.init_conv = nn.Conv3d(num_input_channels, base_filters, 3, 1, 1, bias=self.conv_bias)
114
+
115
+ self.context1 = EncodingModule(base_filters, base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
116
+ inst_norm_affine=True, lrelu_inplace=True)
117
+ self.down1 = DownsamplingModule(base_filters, base_filters * 2, leakiness=1e-2, conv_bias=True,
118
+ inst_norm_affine=True, lrelu_inplace=True)
119
+
120
+ self.context2 = EncodingModule(2 * base_filters, 2 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
121
+ inst_norm_affine=True, lrelu_inplace=True)
122
+ self.down2 = DownsamplingModule(2 * base_filters, base_filters * 4, leakiness=1e-2, conv_bias=True,
123
+ inst_norm_affine=True, lrelu_inplace=True)
124
+
125
+ self.context3 = EncodingModule(4 * base_filters, 4 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
126
+ inst_norm_affine=True, lrelu_inplace=True)
127
+ self.down3 = DownsamplingModule(4 * base_filters, base_filters * 8, leakiness=1e-2, conv_bias=True,
128
+ inst_norm_affine=True, lrelu_inplace=True)
129
+
130
+ self.context4 = EncodingModule(8 * base_filters, 8 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
131
+ inst_norm_affine=True, lrelu_inplace=True)
132
+ self.down4 = DownsamplingModule(8 * base_filters, base_filters * 16, leakiness=1e-2, conv_bias=True,
133
+ inst_norm_affine=True, lrelu_inplace=True)
134
+
135
+ self.context5 = EncodingModule(16 * base_filters, 16 * base_filters, 3, dropout_p, leakiness=1e-2,
136
+ conv_bias=True, inst_norm_affine=True, lrelu_inplace=True)
137
+
138
+ self.bn_after_context5 = nn.InstanceNorm3d(16 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
139
+ self.up1 = UpsamplingModule(16 * base_filters, 8 * base_filters, leakiness=1e-2, conv_bias=True,
140
+ inst_norm_affine=True, lrelu_inplace=True)
141
+
142
+ self.loc1 = LocalizationModule(16 * base_filters, 8 * base_filters, leakiness=1e-2, conv_bias=True,
143
+ inst_norm_affine=True, lrelu_inplace=True)
144
+ self.up2 = UpsamplingModule(8 * base_filters, 4 * base_filters, leakiness=1e-2, conv_bias=True,
145
+ inst_norm_affine=True, lrelu_inplace=True)
146
+
147
+ self.loc2 = LocalizationModule(8 * base_filters, 4 * base_filters, leakiness=1e-2, conv_bias=True,
148
+ inst_norm_affine=True, lrelu_inplace=True)
149
+ self.loc2_seg = nn.Conv3d(4 * base_filters, num_classes, 1, 1, 0, bias=False)
150
+ self.up3 = UpsamplingModule(4 * base_filters, 2 * base_filters, leakiness=1e-2, conv_bias=True,
151
+ inst_norm_affine=True, lrelu_inplace=True)
152
+
153
+ self.loc3 = LocalizationModule(4 * base_filters, 2 * base_filters, leakiness=1e-2, conv_bias=True,
154
+ inst_norm_affine=True, lrelu_inplace=True)
155
+ self.loc3_seg = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False)
156
+ self.up4 = UpsamplingModule(2 * base_filters, 1 * base_filters, leakiness=1e-2, conv_bias=True,
157
+ inst_norm_affine=True, lrelu_inplace=True)
158
+
159
+ self.end_conv_1 = nn.Conv3d(2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias)
160
+ self.end_conv_1_bn = nn.InstanceNorm3d(2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
161
+ self.end_conv_2 = nn.Conv3d(2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias)
162
+ self.end_conv_2_bn = nn.InstanceNorm3d(2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
163
+ self.seg_layer = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False)
164
+
165
+ def forward(self, x):
166
+ seg_outputs = []
167
+
168
+ x = self.init_conv(x)
169
+ x = self.context1(x)
170
+
171
+ skip1, x = self.down1(x)
172
+ x = self.context2(x)
173
+
174
+ skip2, x = self.down2(x)
175
+ x = self.context3(x)
176
+
177
+ skip3, x = self.down3(x)
178
+ x = self.context4(x)
179
+
180
+ skip4, x = self.down4(x)
181
+ x = self.context5(x)
182
+
183
+ x = F.leaky_relu(self.bn_after_context5(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
184
+ x = self.up1(x)
185
+
186
+ x = torch.cat((skip4, x), dim=1)
187
+ x = self.loc1(x)
188
+ x = self.up2(x)
189
+
190
+ x = torch.cat((skip3, x), dim=1)
191
+ x = self.loc2(x)
192
+ loc2_seg = self.final_nonlin(self.loc2_seg(x))
193
+ seg_outputs.append(loc2_seg)
194
+ x = self.up3(x)
195
+
196
+ x = torch.cat((skip2, x), dim=1)
197
+ x = self.loc3(x)
198
+ loc3_seg = self.final_nonlin(self.loc3_seg(x))
199
+ seg_outputs.append(loc3_seg)
200
+ x = self.up4(x)
201
+
202
+ x = torch.cat((skip1, x), dim=1)
203
+ x = F.leaky_relu(self.end_conv_1_bn(self.end_conv_1(x)), negative_slope=self.leakiness,
204
+ inplace=self.lrelu_inplace)
205
+ x = F.leaky_relu(self.end_conv_2_bn(self.end_conv_2(x)), negative_slope=self.leakiness,
206
+ inplace=self.lrelu_inplace)
207
+ x = self.final_nonlin(self.seg_layer(x))
208
+ seg_outputs.append(x)
209
+
210
+ if self.do_ds:
211
+ return seg_outputs[::-1]
212
+ else:
213
+ return seg_outputs[-1]
src/BrainIAC/HD_BET/paths.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ # please refer to the readme on where to get the parameters. Save them in this folder:
4
+ # Original Path: "/media/sdb/divyanshu/divyanshu/aidan_segmentation/nnUNet_pLGG/home/divyanshu/hd-bet_params"
5
+ # Updated path for Docker container:
6
+ folder_with_parameter_files = "/app/BrainIAC/hdbet_model"
src/BrainIAC/HD_BET/predict_case.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ def pad_patient_3D(patient, shape_must_be_divisible_by=16, min_size=None):
6
+ if not (isinstance(shape_must_be_divisible_by, list) or isinstance(shape_must_be_divisible_by, tuple)):
7
+ shape_must_be_divisible_by = [shape_must_be_divisible_by] * 3
8
+ shp = patient.shape
9
+ new_shp = [shp[0] + shape_must_be_divisible_by[0] - shp[0] % shape_must_be_divisible_by[0],
10
+ shp[1] + shape_must_be_divisible_by[1] - shp[1] % shape_must_be_divisible_by[1],
11
+ shp[2] + shape_must_be_divisible_by[2] - shp[2] % shape_must_be_divisible_by[2]]
12
+ for i in range(len(shp)):
13
+ if shp[i] % shape_must_be_divisible_by[i] == 0:
14
+ new_shp[i] -= shape_must_be_divisible_by[i]
15
+ if min_size is not None:
16
+ new_shp = np.max(np.vstack((np.array(new_shp), np.array(min_size))), 0)
17
+ return reshape_by_padding_upper_coords(patient, new_shp, 0), shp
18
+
19
+
20
+ def reshape_by_padding_upper_coords(image, new_shape, pad_value=None):
21
+ shape = tuple(list(image.shape))
22
+ new_shape = tuple(np.max(np.concatenate((shape, new_shape)).reshape((2,len(shape))), axis=0))
23
+ if pad_value is None:
24
+ if len(shape) == 2:
25
+ pad_value = image[0,0]
26
+ elif len(shape) == 3:
27
+ pad_value = image[0, 0, 0]
28
+ else:
29
+ raise ValueError("Image must be either 2 or 3 dimensional")
30
+ res = np.ones(list(new_shape), dtype=image.dtype) * pad_value
31
+ if len(shape) == 2:
32
+ res[0:0+int(shape[0]), 0:0+int(shape[1])] = image
33
+ elif len(shape) == 3:
34
+ res[0:0+int(shape[0]), 0:0+int(shape[1]), 0:0+int(shape[2])] = image
35
+ return res
36
+
37
+
38
+ def predict_case_3D_net(net, patient_data, do_mirroring, num_repeats, BATCH_SIZE=None,
39
+ new_shape_must_be_divisible_by=16, min_size=None, main_device=0, mirror_axes=(2, 3, 4)):
40
+ with torch.no_grad():
41
+ pad_res = []
42
+ for i in range(patient_data.shape[0]):
43
+ t, old_shape = pad_patient_3D(patient_data[i], new_shape_must_be_divisible_by, min_size)
44
+ pad_res.append(t[None])
45
+
46
+ patient_data = np.vstack(pad_res)
47
+
48
+ new_shp = patient_data.shape
49
+
50
+ data = np.zeros(tuple([1] + list(new_shp)), dtype=np.float32)
51
+
52
+ data[0] = patient_data
53
+
54
+ if BATCH_SIZE is not None:
55
+ data = np.vstack([data] * BATCH_SIZE)
56
+
57
+ a = torch.rand(data.shape).float()
58
+
59
+ if main_device == 'cpu':
60
+ pass
61
+ else:
62
+ a = a.cuda(main_device)
63
+
64
+ if do_mirroring:
65
+ x = 8
66
+ else:
67
+ x = 1
68
+ all_preds = []
69
+ for i in range(num_repeats):
70
+ for m in range(x):
71
+ data_for_net = np.array(data)
72
+ do_stuff = False
73
+ if m == 0:
74
+ do_stuff = True
75
+ pass
76
+ if m == 1 and (4 in mirror_axes):
77
+ do_stuff = True
78
+ data_for_net = data_for_net[:, :, :, :, ::-1]
79
+ if m == 2 and (3 in mirror_axes):
80
+ do_stuff = True
81
+ data_for_net = data_for_net[:, :, :, ::-1, :]
82
+ if m == 3 and (4 in mirror_axes) and (3 in mirror_axes):
83
+ do_stuff = True
84
+ data_for_net = data_for_net[:, :, :, ::-1, ::-1]
85
+ if m == 4 and (2 in mirror_axes):
86
+ do_stuff = True
87
+ data_for_net = data_for_net[:, :, ::-1, :, :]
88
+ if m == 5 and (2 in mirror_axes) and (4 in mirror_axes):
89
+ do_stuff = True
90
+ data_for_net = data_for_net[:, :, ::-1, :, ::-1]
91
+ if m == 6 and (2 in mirror_axes) and (3 in mirror_axes):
92
+ do_stuff = True
93
+ data_for_net = data_for_net[:, :, ::-1, ::-1, :]
94
+ if m == 7 and (2 in mirror_axes) and (3 in mirror_axes) and (4 in mirror_axes):
95
+ do_stuff = True
96
+ data_for_net = data_for_net[:, :, ::-1, ::-1, ::-1]
97
+
98
+ if do_stuff:
99
+ _ = a.data.copy_(torch.from_numpy(np.copy(data_for_net)))
100
+ p = net(a) # np.copy is necessary because ::-1 creates just a view i think
101
+ p = p.data.cpu().numpy()
102
+
103
+ if m == 0:
104
+ pass
105
+ if m == 1 and (4 in mirror_axes):
106
+ p = p[:, :, :, :, ::-1]
107
+ if m == 2 and (3 in mirror_axes):
108
+ p = p[:, :, :, ::-1, :]
109
+ if m == 3 and (4 in mirror_axes) and (3 in mirror_axes):
110
+ p = p[:, :, :, ::-1, ::-1]
111
+ if m == 4 and (2 in mirror_axes):
112
+ p = p[:, :, ::-1, :, :]
113
+ if m == 5 and (2 in mirror_axes) and (4 in mirror_axes):
114
+ p = p[:, :, ::-1, :, ::-1]
115
+ if m == 6 and (2 in mirror_axes) and (3 in mirror_axes):
116
+ p = p[:, :, ::-1, ::-1, :]
117
+ if m == 7 and (2 in mirror_axes) and (3 in mirror_axes) and (4 in mirror_axes):
118
+ p = p[:, :, ::-1, ::-1, ::-1]
119
+ all_preds.append(p)
120
+
121
+ stacked = np.vstack(all_preds)[:, :, :old_shape[0], :old_shape[1], :old_shape[2]]
122
+ predicted_segmentation = stacked.mean(0).argmax(0)
123
+ uncertainty = stacked.var(0)
124
+ bayesian_predictions = stacked
125
+ softmax_pred = stacked.mean(0)
126
+ return predicted_segmentation, bayesian_predictions, softmax_pred, uncertainty
src/BrainIAC/HD_BET/run.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import SimpleITK as sitk
4
+ from HD_BET.data_loading import load_and_preprocess, save_segmentation_nifti
5
+ from HD_BET.predict_case import predict_case_3D_net
6
+ import imp
7
+ from HD_BET.utils import postprocess_prediction, SetNetworkToVal, get_params_fname, maybe_download_parameters
8
+ import os
9
+ import HD_BET
10
+
11
+
12
+ def apply_bet(img, bet, out_fname):
13
+ img_itk = sitk.ReadImage(img)
14
+ img_npy = sitk.GetArrayFromImage(img_itk)
15
+ img_bet = sitk.GetArrayFromImage(sitk.ReadImage(bet))
16
+ img_npy[img_bet == 0] = 0
17
+ out = sitk.GetImageFromArray(img_npy)
18
+ out.CopyInformation(img_itk)
19
+ sitk.WriteImage(out, out_fname)
20
+
21
+
22
+ def run_hd_bet(mri_fnames, output_fnames, mode="accurate", config_file=os.path.join(HD_BET.__path__[0], "config.py"), device=0,
23
+ postprocess=False, do_tta=True, keep_mask=True, overwrite=True):
24
+ """
25
+
26
+ :param mri_fnames: str or list/tuple of str
27
+ :param output_fnames: str or list/tuple of str. If list: must have the same length as output_fnames
28
+ :param mode: fast or accurate
29
+ :param config_file: config.py
30
+ :param device: either int (for device id) or 'cpu'
31
+ :param postprocess: whether to do postprocessing or not. Postprocessing here consists of simply discarding all
32
+ but the largest predicted connected component. Default False
33
+ :param do_tta: whether to do test time data augmentation by mirroring along all axes. Default: True. If you use
34
+ CPU you may want to turn that off to speed things up
35
+ :return:
36
+ """
37
+
38
+ list_of_param_files = []
39
+
40
+ if mode == 'fast':
41
+ params_file = get_params_fname(0)
42
+ maybe_download_parameters(0)
43
+
44
+ list_of_param_files.append(params_file)
45
+ elif mode == 'accurate':
46
+ for i in range(5):
47
+ params_file = get_params_fname(i)
48
+ maybe_download_parameters(i)
49
+
50
+ list_of_param_files.append(params_file)
51
+ else:
52
+ raise ValueError("Unknown value for mode: %s. Expected: fast or accurate" % mode)
53
+
54
+ assert all([os.path.isfile(i) for i in list_of_param_files]), "Could not find parameter files"
55
+
56
+ cf = imp.load_source('cf', config_file)
57
+ cf = cf.config()
58
+
59
+ net, _ = cf.get_network(cf.val_use_train_mode, None)
60
+ if device == "cpu":
61
+ net = net.cpu()
62
+ else:
63
+ net.cuda(device)
64
+
65
+ if not isinstance(mri_fnames, (list, tuple)):
66
+ mri_fnames = [mri_fnames]
67
+
68
+ if not isinstance(output_fnames, (list, tuple)):
69
+ output_fnames = [output_fnames]
70
+
71
+ assert len(mri_fnames) == len(output_fnames), "mri_fnames and output_fnames must have the same length"
72
+
73
+ params = []
74
+ for p in list_of_param_files:
75
+ params.append(torch.load(p, map_location=lambda storage, loc: storage))
76
+
77
+ for in_fname, out_fname in zip(mri_fnames, output_fnames):
78
+ mask_fname = out_fname[:-7] + "_mask.nii.gz"
79
+ if overwrite or (not (os.path.isfile(mask_fname) and keep_mask) or not os.path.isfile(out_fname)):
80
+ print("File:", in_fname)
81
+ print("preprocessing...")
82
+ try:
83
+ data, data_dict = load_and_preprocess(in_fname)
84
+ except RuntimeError:
85
+ print("\nERROR\nCould not read file", in_fname, "\n")
86
+ continue
87
+ except AssertionError as e:
88
+ print(e)
89
+ continue
90
+
91
+ softmax_preds = []
92
+
93
+ print("prediction (CNN id)...")
94
+ for i, p in enumerate(params):
95
+ print(i)
96
+ net.load_state_dict(p)
97
+ net.eval()
98
+ net.apply(SetNetworkToVal(False, False))
99
+ _, _, softmax_pred, _ = predict_case_3D_net(net, data, do_tta, cf.val_num_repeats,
100
+ cf.val_batch_size, cf.net_input_must_be_divisible_by,
101
+ cf.val_min_size, device, cf.da_mirror_axes)
102
+ softmax_preds.append(softmax_pred[None])
103
+
104
+ seg = np.argmax(np.vstack(softmax_preds).mean(0), 0)
105
+
106
+ if postprocess:
107
+ seg = postprocess_prediction(seg)
108
+
109
+ print("exporting segmentation...")
110
+ save_segmentation_nifti(seg, data_dict, mask_fname)
111
+
112
+ apply_bet(in_fname, mask_fname, out_fname)
113
+
114
+ if not keep_mask:
115
+ os.remove(mask_fname)
116
+
117
+
src/BrainIAC/HD_BET/utils.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from urllib.request import urlopen
2
+ import torch
3
+ from torch import nn
4
+ import numpy as np
5
+ from skimage.morphology import label
6
+ import os
7
+ from HD_BET.paths import folder_with_parameter_files
8
+
9
+
10
+ def get_params_fname(fold):
11
+ return os.path.join(folder_with_parameter_files, "%d.model" % fold)
12
+
13
+
14
+ def maybe_download_parameters(fold=0, force_overwrite=False):
15
+ """
16
+ Downloads the parameters for some fold if it is not present yet.
17
+ :param fold:
18
+ :param force_overwrite: if True the old parameter file will be deleted (if present) prior to download
19
+ :return:
20
+ """
21
+
22
+ assert 0 <= fold <= 4, "fold must be between 0 and 4"
23
+
24
+ if not os.path.isdir(folder_with_parameter_files):
25
+ maybe_mkdir_p(folder_with_parameter_files)
26
+
27
+ out_filename = get_params_fname(fold)
28
+
29
+ if force_overwrite and os.path.isfile(out_filename):
30
+ os.remove(out_filename)
31
+
32
+ if not os.path.isfile(out_filename):
33
+ url = "https://zenodo.org/record/2540695/files/%d.model?download=1" % fold
34
+ print("Downloading", url, "...")
35
+ data = urlopen(url).read()
36
+ #out_filename = "/media/sdb/divyanshu/divyanshu/aidan_segmentation/nnUNet_pLGG/home/divyanshu/hd-bet_params/0.model"
37
+ with open(out_filename, 'wb') as f:
38
+ f.write(data)
39
+
40
+
41
+ def init_weights(module):
42
+ if isinstance(module, nn.Conv3d):
43
+ module.weight = nn.init.kaiming_normal(module.weight, a=1e-2)
44
+ if module.bias is not None:
45
+ module.bias = nn.init.constant(module.bias, 0)
46
+
47
+
48
+ def softmax_helper(x):
49
+ rpt = [1 for _ in range(len(x.size()))]
50
+ rpt[1] = x.size(1)
51
+ x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
52
+ e_x = torch.exp(x - x_max)
53
+ return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)
54
+
55
+
56
+ class SetNetworkToVal(object):
57
+ def __init__(self, use_dropout_sampling=False, norm_use_average=True):
58
+ self.norm_use_average = norm_use_average
59
+ self.use_dropout_sampling = use_dropout_sampling
60
+
61
+ def __call__(self, module):
62
+ if isinstance(module, nn.Dropout3d) or isinstance(module, nn.Dropout2d) or isinstance(module, nn.Dropout):
63
+ module.train(self.use_dropout_sampling)
64
+ elif isinstance(module, nn.InstanceNorm3d) or isinstance(module, nn.InstanceNorm2d) or \
65
+ isinstance(module, nn.InstanceNorm1d) \
66
+ or isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm3d) or \
67
+ isinstance(module, nn.BatchNorm1d):
68
+ module.train(not self.norm_use_average)
69
+
70
+
71
+ def postprocess_prediction(seg):
72
+ # basically look for connected components and choose the largest one, delete everything else
73
+ print("running postprocessing... ")
74
+ mask = seg != 0
75
+ lbls = label(mask, connectivity=mask.ndim)
76
+ lbls_sizes = [np.sum(lbls == i) for i in np.unique(lbls)]
77
+ largest_region = np.argmax(lbls_sizes[1:]) + 1
78
+ seg[lbls != largest_region] = 0
79
+ return seg
80
+
81
+
82
+ def subdirs(folder, join=True, prefix=None, suffix=None, sort=True):
83
+ if join:
84
+ l = os.path.join
85
+ else:
86
+ l = lambda x, y: y
87
+ res = [l(folder, i) for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))
88
+ and (prefix is None or i.startswith(prefix))
89
+ and (suffix is None or i.endswith(suffix))]
90
+ if sort:
91
+ res.sort()
92
+ return res
93
+
94
+
95
+ def subfiles(folder, join=True, prefix=None, suffix=None, sort=True):
96
+ if join:
97
+ l = os.path.join
98
+ else:
99
+ l = lambda x, y: y
100
+ res = [l(folder, i) for i in os.listdir(folder) if os.path.isfile(os.path.join(folder, i))
101
+ and (prefix is None or i.startswith(prefix))
102
+ and (suffix is None or i.endswith(suffix))]
103
+ if sort:
104
+ res.sort()
105
+ return res
106
+
107
+
108
+ subfolders = subdirs # I am tired of confusing those
109
+
110
+
111
+ def maybe_mkdir_p(directory):
112
+ splits = directory.split("/")[1:]
113
+ for i in range(0, len(splits)):
114
+ if not os.path.isdir(os.path.join("", *splits[:i+1])):
115
+ os.mkdir(os.path.join("", *splits[:i+1]))
src/BrainIAC/IDHprediction/README.md ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # IDH Mutation Classification
2
+
3
+ <p align="left">
4
+ <img src="idh.jpeg" width="200" alt="IDH Mutation Classification Example"/>
5
+ </p>
6
+
7
+ ## Overview
8
+
9
+ We present the IDH mutation classification training and inference code for BrainIAC as a downstream task. The pipeline is trained and infered on T1CE and FLAIR scans, with AUC and F1 as evaluation metric.
10
+
11
+ ## Data Requirements
12
+
13
+ - **Input**: T1CE and FLAIR MR sequences from a single scan
14
+ - **Format**: NIFTI (.nii.gz)
15
+ - **Preprocessing**: Bias field corrected, registered to standard space, skull stripped
16
+ - **CSV Structure**:
17
+ ```
18
+ pat_id,scandate,label
19
+ subject001,scan_sequence,1 # 1 for IDH mutant, 0 for wildtype
20
+ ```
21
+ refer to [ quickstart.ipynb](../quickstart.ipynb) to find how to preprocess data and generate csv file.
22
+
23
+ ## Setup
24
+
25
+ 1. **Configuration**:
26
+ change the [config.yml](../config.yml) file accordingly.
27
+ ```yaml
28
+ # config.yml
29
+ data:
30
+ train_csv: "path/to/train.csv"
31
+ val_csv: "path/to/val.csv"
32
+ test_csv: "path/to/test.csv"
33
+ root_dir: "../data/sample/processed"
34
+ collate: 2 # two sequence pipeline
35
+
36
+ checkpoints: "./checkpoints/idh_model.00" # for inference/testing
37
+
38
+ train:
39
+ finetune: 'yes' # yes to finetune the entire model
40
+ freeze: 'no' # yes to freeze the resnet backbone
41
+ weights: ./checkpoints/brainiac.ckpt # path to brainiac weights
42
+ ```
43
+
44
+ 2. **Training**:
45
+ ```bash
46
+ python -m IDHprediction.train_idh
47
+ ```
48
+
49
+ 3. **Inference**:
50
+ ```bash
51
+ python -m IDHprediction.infer_idh
52
+ ```
53
+
src/BrainIAC/IDHprediction/__init__.py ADDED
File without changes
src/BrainIAC/IDHprediction/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (150 Bytes). View file
 
src/BrainIAC/IDHprediction/__pycache__/infer_idh.cpython-39.pyc ADDED
Binary file (4.82 kB). View file
 
src/BrainIAC/IDHprediction/idh.jpeg ADDED