# This config file implements the training workflow. It can be combined with multi_gpu_train.yaml to use DDP for
# multi-GPU runs.
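#
# A sketch of how this config might be launched with the MONAI bundle CLI; the paths and file layout
# below are illustrative and may differ for your bundle:
#   python -m monai.bundle run --config_file configs/train.yaml --bundle_root .
#   torchrun --nproc_per_node=2 -m monai.bundle run \
#       --config_file "['configs/train.yaml','configs/multi_gpu_train.yaml']" --bundle_root .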

imports:
- $import os
- $import json
- $import datetime
- $import torch
- $import glob

# pull out some constants from MONAI
image: $monai.utils.CommonKeys.IMAGE
label: $monai.utils.CommonKeys.LABEL
pred: $monai.utils.CommonKeys.PRED

# multi-GPU values, `rank` is overridden by the separate multi-GPU config (multi_gpu_train.yaml)
rank: 0  # without multi-gpu support consider the process as rank 0 anyway
is_not_rank0: '$@rank > 0'  # true if not main process, used to disable handlers for other ranks

# hyperparameters for you to modify on the command line (see the example overrides below)
val_interval: 1  # how often to perform validation after an epoch
ckpt_interval: 1  # how often to save a checkpoint after an epoch
rand_prob: 0.5  # probability a random transform is applied
batch_size: 5  # number of images per batch
num_epochs: 10  # number of epochs to train for
num_substeps: 1  # how many times to repeat training on the same batch
num_workers: 4  # number of workers to generate batches with
learning_rate: 0.001  # initial learning rate
num_classes: 4  # number of classes in training data which network should predict
device: $torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
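
# Any of the hyperparameters above can be overridden when launching the bundle; an illustrative example:
#   python -m monai.bundle run --config_file configs/train.yaml --batch_size 10 --learning_rate 0.0001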

# define various paths
bundle_root: .  # root directory of the bundle
ckpt_path: $@bundle_root + '/models/model.pt'  # checkpoint to load before starting
dataset_dir: $@bundle_root + '/data/train_data'  # where data is coming from
results_dir: $@bundle_root + '/results'  # where results are being stored to
# a new output directory is chosen using a timestamp for every invocation
output_dir: '$datetime.datetime.now().strftime(@results_dir + ''/output_%y%m%d_%H%M%S'')'

# network definition, this could be parameterised by pre-defined values or on the command line
network_def:
  _target_: DenseNet121
  spatial_dims: 2
  in_channels: 1
  out_channels: '@num_classes'
network: $@network_def.to(@device)
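
# Nested keys of the network definition are assumed to be overridable from the command line as well,
# using MONAI's '#' id separator; an illustrative example (quote the argument if your shell needs it):
#   python -m monai.bundle run --config_file configs/train.yaml '--network_def#spatial_dims' 3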

# dataset values, these assume a JSON file listing the image files (img##.nii.gz) and their labels
data_json: $@bundle_root + '/data/train_samples.json'  # JSON file of training images and labels
data_fp: "$open(@data_json,'r', encoding='utf8')"
data_dict: "$json.load(@data_fp)"
partitions: '$monai.data.partition_dataset(@data_dict, (4, 1), shuffle=True, seed=0)'  # 4:1 train/validation split
train_sub: '$@partitions[0]' # train partition
val_sub: '$@partitions[1]' # validation partition
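
# The JSON file is assumed to contain a list of dictionaries keyed by the CommonKeys names above,
# for example (illustrative filenames and class indices):
#   [{"image": "data/train_data/img00.nii.gz", "label": 0},
#    {"image": "data/train_data/img01.nii.gz", "label": 3}]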

# these transforms are used for training and validation transform sequences
base_transforms:
- _target_: LoadImaged
  keys: '@image'
- _target_: EnsureChannelFirstd
  keys: '@image'

# these are the random and regularising transforms used only for training
train_transforms:
- _target_: RandAxisFlipd
  keys: '@image'
  prob: '@rand_prob'
- _target_: RandRotate90d
  keys: '@image'
  prob: '@rand_prob'
- _target_: RandGaussianNoised
  keys: '@image'
  prob: '@rand_prob'
  std: 0.05
- _target_: ScaleIntensityd
  keys: '@image'

# these are used for validation data so no randomness
val_transforms:
- _target_: ScaleIntensityd
  keys: '@image'

# define the Compose objects for training and validation
preprocessing:
  _target_: Compose
  transforms: $@base_transforms + @train_transforms

val_preprocessing:
  _target_: Compose
  transforms: $@base_transforms + @val_transforms

# define the datasets for training and validation
train_dataset:
  _target_: Dataset
  data: '@train_sub'
  transform: '@preprocessing'

val_dataset:
  _target_: Dataset
  data: '@val_sub'
  transform: '@val_preprocessing'

# define the dataloaders for training and validation
train_dataloader:
  _target_: ThreadDataLoader  # generate data asynchronously from training
  dataset: '@train_dataset'
  batch_size: '@batch_size'
  repeats: '@num_substeps'
  num_workers: '@num_workers'

val_dataloader:
  _target_: DataLoader  # faster transforms probably won't benefit from threading
  dataset: '@val_dataset'
  batch_size: '@batch_size'
  num_workers: '@num_workers'

# Simple CrossEntropy loss configured for multi-class classification
lossfn:
  _target_: torch.nn.CrossEntropyLoss
  reduction: sum

# hyperparameters could be added for other arguments of this class
optimizer:
  _target_: torch.optim.Adam
  params: [email protected]()
  lr: '@learning_rate'
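  # other torch.optim.Adam arguments could be exposed in the same way, e.g. a hypothetical
  # '@weight_decay' hyperparameter defined alongside '@learning_rate' above:
  # weight_decay: '@weight_decay'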

# replace with another inferer type if your network requires a different inference process during training
inferer:
  _target_: SimpleInferer

# transforms applied to the network output and labels so they are in the form expected by the validation metrics
postprocessing:
  _target_: Compose
  transforms:
  - _target_: Activationsd
    keys: '@pred'
    softmax: true
  - _target_: AsDiscreted
    keys: ['@pred', '@label']
    argmax: [true, false]
    to_onehot: '@num_classes'
  - _target_: ToTensord
    keys: ['@pred', '@label']
    device: '@device'

# validation handlers to gather statistics, log these to a file, and save best checkpoint
val_handlers:
- _target_: StatsHandler
  name: null  # use engine.logger as the Logger object to log to
  output_transform: '$lambda x: None'
- _target_: LogfileHandler  # log outputs from the validation engine
  output_dir: '@output_dir'
- _target_: CheckpointSaver
  _disabled_: '@is_not_rank0'  # only need rank 0 to save
  save_dir: '@output_dir'
  save_dict:
    model: '@network'
  save_interval: 0  # don't save iterations, just when the metric improves
  save_final: false
  epoch_level: false
  save_key_metric: true
  key_metric_name: val_accuracy  # save the checkpoint when this value improves

# engine for running validation, ties together objects defined above and has metric definitions
evaluator:
  _target_: SupervisedEvaluator
  device: '@device'
  val_data_loader: '@val_dataloader'
  network: '@network'
  postprocessing: '@postprocessing'
  key_val_metric:
    val_accuracy:
      _target_: ignite.metrics.Accuracy
      output_transform: $monai.handlers.from_engine([@pred, @label])
  additional_metrics:
    val_f1:  # can have other metrics
      _target_: ConfusionMatrix
      metric_name: 'f1 score'
      output_transform: $monai.handlers.from_engine([@pred, @label])
  val_handlers: '@val_handlers'

# gathers the loss and validation values for each iteration, referred to by CheckpointSaver so defined separately
metriclogger:
  _target_: MetricLogger
  evaluator: '@evaluator'

handlers:
- '@metriclogger'
- _target_: CheckpointLoader
  _disabled_: $not os.path.exists(@ckpt_path)
  load_path: '@ckpt_path'
  load_dict:
    model: '@network'
- _target_: ValidationHandler  # run validation at the set interval, bridge between trainer and evaluator objects
  validator: '@evaluator'
  epoch_level: true
  interval: '@val_interval'
- _target_: CheckpointSaver
  _disabled_: '@is_not_rank0'  # only need rank 0 to save
  save_dir: '@output_dir'
  save_dict:  # every epoch checkpoint saves the network and the metric logger in a dictionary
    model: '@network'
    logger: '@metriclogger'
  save_interval: '@ckpt_interval'
  save_final: true
  epoch_level: true
- _target_: StatsHandler
  name: null  # use engine.logger as the Logger object to log to
  tag_name: train_loss
  output_transform: $monai.handlers.from_engine(['loss'], first=True)  # log loss value
- _target_: LogfileHandler  # log outputs from the training engine
  output_dir: '@output_dir'

# engine for training, ties values defined above together into the main engine for the training process
trainer:
  _target_: SupervisedTrainer
  max_epochs: '@num_epochs'
  device: '@device'
  train_data_loader: '@train_dataloader'
  network: '@network'
  inferer: '@inferer'  # unnecessary since SimpleInferer is the default if this isn't provided
  loss_function: '@lossfn'
  optimizer: '@optimizer'
  # postprocessing: '@postprocessing'  # uncomment if you have train metrics that need post-processing
  key_train_metric: null
  train_handlers: '@handlers'

initialize:
- "$monai.utils.set_determinism(seed=123)"
- "$setattr(torch.backends.cudnn, 'benchmark', True)"
run:
- "[email protected]()"