# project-monai's picture
# Upload classification_template version 0.0.3
# 38fd365 verified
# This implements the workflow for applying the network to a directory of images and measuring network performance with metrics.
# Python modules made available to the $-expressions in this config
# NOTE(review): glob is not referenced anywhere in this file — possibly kept for CLI overrides; confirm
imports:
- $import os
- $import json
- $import torch
- $import glob
# pull out some constants from MONAI
image: $monai.utils.CommonKeys.IMAGE
label: $monai.utils.CommonKeys.LABEL
pred: $monai.utils.CommonKeys.PRED
# hyperparameters for you to modify on the command line
batch_size: 1 # number of images per batch
num_workers: 0 # number of workers to generate batches with
num_classes: 4 # number of classes in training data which network should predict
# use the first GPU when CUDA is available, otherwise fall back to the CPU
device: $torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# define various paths
bundle_root: . # root directory of the bundle
ckpt_path: $@bundle_root + '/models/model.pt' # checkpoint to load before starting
# NOTE(review): dataset_dir is not referenced elsewhere in this file — confirm it is still needed
dataset_dir: $@bundle_root + '/data/test_data' # where data is coming from
# network definition, this could be parameterised by pre-defined values or on the command line
network_def:
  _target_: DenseNet121 # resolved by MONAI's class-name lookup; presumably monai.networks.nets — confirm
  spatial_dims: 2 # 2D image inputs
  in_channels: 1 # single-channel (grayscale) inputs
  out_channels: '@num_classes' # one output per predicted class
# instantiate the network definition and move it to the chosen device
network: $@network_def.to(@device)
# path to a JSON file listing the test samples to evaluate
test_json: "$@bundle_root+'/data/test_samples.json'"
# open the sample-list file
# NOTE(review): this handle is never explicitly closed — tolerable for a one-shot
# inference run, but worth tidying
test_fp: "$open(@test_json,'r', encoding='utf8')"
# load json file
test_dict: "$json.load(@test_fp)"
# these transforms are used for inference to load and regularise inputs
transforms:
- _target_: LoadImaged # read the image referenced under the @image key
  keys: '@image'
- _target_: EnsureChannelFirstd # move/insert the channel dimension first
  keys: '@image'
- _target_: ScaleIntensityd # normalise image intensities
  keys: '@image'
# compose the transform list into a single preprocessing callable
preprocessing:
  _target_: Compose
  transforms: $@transforms
# dataset pairing the loaded sample list with the preprocessing transforms
dataset:
  _target_: Dataset
  data: '@test_dict'
  transform: '@preprocessing'
dataloader:
  _target_: ThreadDataLoader # generate data asynchronously from inference
  dataset: '@dataset'
  batch_size: '@batch_size'
  num_workers: '@num_workers'
# should be replaced with other inferer types if training process is different for your network
inferer:
  _target_: SimpleInferer
# transform to apply to data from network to be suitable for validation
postprocessing:
  _target_: Compose
  transforms:
  - _target_: Activationsd
    keys: '@pred'
    softmax: true # convert logits to class probabilities
  - _target_: AsDiscreted
    keys: ['@pred', '@label']
    argmax: [true, false] # argmax predictions only; labels pass through
    to_onehot: '@num_classes' # one-hot both keys for the metrics below
  - _target_: ToTensord
    keys: ['@pred', '@label']
    device: '@device'
# inference handlers to load checkpoint, gather statistics
val_handlers:
- _target_: CheckpointLoader
  _disabled_: $not os.path.exists(@ckpt_path) # skip loading when no checkpoint file exists
  load_path: '@ckpt_path'
  load_dict:
    model: '@network' # restore weights into the network defined above
- _target_: StatsHandler
  name: null # use engine.logger as the Logger object to log to
  output_transform: '$lambda x: None' # suppress per-iteration output logging
# engine for running inference, ties together objects defined above and has metric definitions
evaluator:
  _target_: SupervisedEvaluator
  device: '@device'
  val_data_loader: '@dataloader'
  network: '@network'
  inferer: '@inferer'
  postprocessing: '@postprocessing'
  # primary metric: classification accuracy over the one-hot pred/label pairs
  key_val_metric:
    val_accuracy:
      _target_: ignite.metrics.Accuracy
      output_transform: $monai.handlers.from_engine([@pred, @label])
  additional_metrics:
    val_f1: # can have other metrics
      _target_: ConfusionMatrix
      metric_name: 'f1 score'
      output_transform: $monai.handlers.from_engine([@pred, @label])
  val_handlers: '@val_handlers'
# enable cuDNN benchmark mode to autotune kernels before running
initialize:
- "$setattr(torch.backends.cudnn, 'benchmark', True)"
# entry point: execute the evaluator defined above
# (fixes a scraping artifact: '[email protected]' was an email-obfuscation
# replacement of the original '$@evaluator.run' expression)
run:
- "$@evaluator.run()"