# This config file implements the training workflow. It can be combined with multi_gpu_train.yaml to use DDP for
# multi-GPU runs.
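# A hypothetical invocation with the MONAI bundle CLI (paths are illustrative; recent MONAI versions execute the `run` section by default):
#   python -m monai.bundle run --config_file configs/train.yaml --bundle_root .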
imports:
- $import os
- $import json
- $import datetime
- $import torch
- $import glob
# pull out some constants from MONAI
image: $monai.utils.CommonKeys.IMAGE
label: $monai.utils.CommonKeys.LABEL
pred: $monai.utils.CommonKeys.PRED
# multi-GPU values, `rank` will be replaced by a separate script or config implementing the multi-GPU changes
rank: 0 # without multi-gpu support consider the process as rank 0 anyway
is_not_rank0: '$@rank > 0' # true if not main process, used to disable handlers for other ranks
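# (a multi_gpu_train.yaml override would typically redefine `rank`, e.g. as $dist.get_rank(), and wrap the network in DistributedDataParallel)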
# hyperparameters for you to modify on the command line
val_interval: 1 # how often to perform validation after an epoch
ckpt_interval: 1 # how often to save a checkpoint after an epoch
rand_prob: 0.5 # probability a random transform is applied
batch_size: 5 # number of images per batch
num_epochs: 10 # number of epochs to train for
num_substeps: 1 # number of times each batch is repeated for training
num_workers: 4 # number of workers to generate batches with
learning_rate: 0.001 # initial learning rate
num_classes: 4 # number of classes in training data which network should predict
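# any of the above can be overridden when launching the bundle, e.g. (illustrative values): --batch_size 10 --num_epochs 20 --learning_rate 0.0001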
device: $torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# define various paths
bundle_root: . # root directory of the bundle
ckpt_path: $@bundle_root + '/models/model.pt' # checkpoint to load before starting
dataset_dir: $@bundle_root + '/data/train_data' # where data is coming from
results_dir: $@bundle_root + '/results' # where results are being stored to
# a new output directory is chosen using a timestamp for every invocation
output_dir: '$datetime.datetime.now().strftime(@results_dir + ''/output_%y%m%d_%H%M%S'')'
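# e.g. a run started at 13:45:10 on 2024-01-31 would write results to ./results/output_240131_134510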
# network definition, this could be parameterised by pre-defined values or on the command line
network_def:
_target_: DenseNet121
spatial_dims: 2
in_channels: 1
out_channels: '@num_classes'
network: $@network_def.to(@device)
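# roughly equivalent Python (for reference): network = monai.networks.nets.DenseNet121(spatial_dims=2, in_channels=1, out_channels=num_classes).to(device)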
# dataset value, this assumes a JSON file listing img##.nii.gz image files and their labels
data_json: $@bundle_root + '/data/train_samples.json' # JSON file listing the training images and labels
data_fp: "$open(@data_json,'r', encoding='utf8')"
data_dict: "$json.load(@data_fp)"
partitions: '$monai.data.partition_dataset(@data_dict, (4, 1), shuffle=True, seed=0)'
train_sub: '$@partitions[0]' # train partition
val_sub: '$@partitions[1]' # validation partition
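# the JSON is assumed to be a list of dicts keyed by @image/@label, e.g. (illustrative):
#   [{"image": "img00.nii.gz", "label": 0}, {"image": "img01.nii.gz", "label": 1}, ...]
# partition_dataset then splits it 4:1 (shuffled with a fixed seed) into training and validation subsets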
# these transforms are used for training and validation transform sequences
base_transforms:
- _target_: LoadImaged
keys: '@image'
- _target_: EnsureChannelFirstd
keys: '@image'
# these are the random and regularising transforms used only for training
train_transforms:
- _target_: RandAxisFlipd
keys: '@image'
prob: '@rand_prob'
- _target_: RandRotate90d
keys: '@image'
prob: '@rand_prob'
- _target_: RandGaussianNoised
keys: '@image'
prob: '@rand_prob'
std: 0.05
- _target_: ScaleIntensityd
keys: '@image'
# these are used for validation data so no randomness
val_transforms:
- _target_: ScaleIntensityd
keys: '@image'
# define the Compose objects for training and validation
preprocessing:
_target_: Compose
transforms: $@base_transforms + @train_transforms
val_preprocessing:
_target_: Compose
transforms: $@base_transforms + @val_transforms
# define the datasets for training and validation
train_dataset:
_target_: Dataset
data: '@train_sub'
transform: '@preprocessing'
val_dataset:
_target_: Dataset
data: '@val_sub'
transform: '@val_preprocessing'
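# Dataset applies the transforms on the fly per item; CacheDataset could be substituted to cache the deterministic transforms if the data fits in memory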
# define the dataloaders for training and validation
train_dataloader:
  _target_: ThreadDataLoader # generate data asynchronously from training
dataset: '@train_dataset'
batch_size: '@batch_size'
repeats: '@num_substeps'
num_workers: '@num_workers'
val_dataloader:
_target_: DataLoader # faster transforms probably won't benefit from threading
dataset: '@val_dataset'
batch_size: '@batch_size'
num_workers: '@num_workers'
# Simple CrossEntropy loss configured for multi-class classification
lossfn:
_target_: torch.nn.CrossEntropyLoss
reduction: sum
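# nested entries can also be overridden on the command line with the `#` separator, e.g. (illustrative): --lossfn#reduction mean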
# hyperparameters could be added for other arguments of this class
optimizer:
_target_: torch.optim.Adam
  params: $@network.parameters()
lr: '@learning_rate'
# should be replaced with other inferer types if training process is different for your network
inferer:
_target_: SimpleInferer
# transforms applied to the network output so it is suitable for the validation metrics
postprocessing:
_target_: Compose
transforms:
- _target_: Activationsd
keys: '@pred'
softmax: true
- _target_: AsDiscreted
keys: ['@pred', '@label']
argmax: [true, false]
to_onehot: '@num_classes'
- _target_: ToTensord
keys: ['@pred', '@label']
device: '@device'
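# net effect: @pred is softmaxed, argmaxed, and one-hot encoded to @num_classes channels, and @label is one-hot encoded,
# so the Accuracy and ConfusionMatrix metrics below compare tensors of the same shape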
# validation handlers to gather statistics, log these to a file, and save best checkpoint
val_handlers:
- _target_: StatsHandler
name: null # use engine.logger as the Logger object to log to
output_transform: '$lambda x: None'
- _target_: LogfileHandler # log outputs from the validation engine
output_dir: '@output_dir'
- _target_: CheckpointSaver
_disabled_: '@is_not_rank0' # only need rank 0 to save
save_dir: '@output_dir'
save_dict:
model: '@network'
save_interval: 0 # don't save iterations, just when the metric improves
save_final: false
epoch_level: false
save_key_metric: true
key_metric_name: val_accuracy # save the checkpoint when this value improves
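# the best checkpoint is written under @output_dir, typically named from the save_dict key and metric value, e.g. model_key_metric=0.9250.pt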
# engine for running validation, ties together objects defined above and has metric definitions
evaluator:
_target_: SupervisedEvaluator
device: '@device'
val_data_loader: '@val_dataloader'
network: '@network'
postprocessing: '@postprocessing'
key_val_metric:
val_accuracy:
_target_: ignite.metrics.Accuracy
output_transform: $monai.handlers.from_engine([@pred, @label])
additional_metrics:
val_f1: # can have other metrics
_target_: ConfusionMatrix
metric_name: 'f1 score'
output_transform: $monai.handlers.from_engine([@pred, @label])
val_handlers: '@val_handlers'
# gathers the loss and validation values for each iteration; defined separately so the CheckpointSaver below can refer to it
metriclogger:
_target_: MetricLogger
evaluator: '@evaluator'
handlers:
- '@metriclogger'
- _target_: CheckpointLoader
_disabled_: $not os.path.exists(@ckpt_path)
load_path: '@ckpt_path'
load_dict:
model: '@network'
- _target_: ValidationHandler # run validation at the set interval, bridge between trainer and evaluator objects
validator: '@evaluator'
epoch_level: true
interval: '@val_interval'
- _target_: CheckpointSaver
_disabled_: '@is_not_rank0' # only need rank 0 to save
save_dir: '@output_dir'
save_dict: # every epoch checkpoint saves the network and the metric logger in a dictionary
model: '@network'
logger: '@metriclogger'
save_interval: '@ckpt_interval'
save_final: true
epoch_level: true
- _target_: StatsHandler
name: null # use engine.logger as the Logger object to log to
tag_name: train_loss
output_transform: $monai.handlers.from_engine(['loss'], first=True) # log loss value
- _target_: LogfileHandler # log outputs from the training engine
output_dir: '@output_dir'
# engine for training, ties values defined above together into the main engine for the training process
trainer:
_target_: SupervisedTrainer
max_epochs: '@num_epochs'
device: '@device'
train_data_loader: '@train_dataloader'
network: '@network'
inferer: '@inferer' # unnecessary since SimpleInferer is the default if this isn't provided
loss_function: '@lossfn'
optimizer: '@optimizer'
# postprocessing: '@postprocessing' # uncomment if you have train metrics that need post-processing
key_train_metric: null
train_handlers: '@handlers'
initialize:
- "$monai.utils.set_determinism(seed=123)"
- "$setattr(torch.backends.cudnn, 'benchmark', True)"
run:
- "[email protected]()"