# File size: 3,433 Bytes
# 38fd365
# This implements the workflow for applying the network to the test images listed in a JSON file
# and measuring network performance with metrics.

imports:
- $import glob
- $import json
- $import os
- $import pathlib
- $import torch

# Dictionary key names shared by the transforms, postprocessing and metrics
# below, taken from monai.utils.CommonKeys rather than hard-coded.
image: '$monai.utils.CommonKeys.IMAGE'
label: '$monai.utils.CommonKeys.LABEL'
pred: '$monai.utils.CommonKeys.PRED'

# Hyperparameters; any of these can be overridden from the command line.
batch_size: 1    # images per batch
num_workers: 0   # worker count used to generate batches
num_classes: 4   # classes the network predicts (as in the training data)
# First CUDA device when one is available, otherwise the CPU.
device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"

# Filesystem locations, expressed relative to the bundle root.
bundle_root: '.'  # bundle root directory
ckpt_path: "$@bundle_root + '/models/model.pt'"  # checkpoint to load before evaluating
# Input data directory. NOTE(review): not referenced elsewhere in this config;
# the samples are read from test_json instead.
dataset_dir: "$@bundle_root + '/data/test_data'"

# Network definition; its arguments may reference the values above or be
# overridden on the command line.
network_def:
  _target_: DenseNet121
  spatial_dims: 2               # 2D inputs
  in_channels: 1                # single input channel
  out_channels: '@num_classes'  # one output per class
# The instantiated network, moved onto @device.
network: '$@network_def.to(@device)'

# Path of the JSON file describing the test samples.
test_json: "$@bundle_root+'/data/test_samples.json'"
# Open handle on test_json. Kept so existing references to @test_fp still
# resolve; test_dict below no longer reads from it (presumably it is then only
# evaluated if something references it, since ConfigParser resolves lazily -
# confirm).
test_fp: "$open(@test_json,'r', encoding='utf8')"
# Parsed contents of test_json, used as the Dataset's `data` (presumably a list
# of records keyed by @image / @label - verify against the JSON file).
# Read via pathlib so the file is closed immediately after reading, instead of
# an open handle being kept alive inside the config.
test_dict: "$json.loads(pathlib.Path(@test_json).read_text(encoding='utf8'))"

# Inference-time preprocessing applied to the @image entry of each sample:
# load the image, put the channel dimension first, then rescale intensities.
transforms:
- _target_: LoadImaged
  keys: "@image"
- _target_: EnsureChannelFirstd
  keys: "@image"
- _target_: ScaleIntensityd
  keys: "@image"

# The transforms above chained into a single callable.
preprocessing:
  _target_: Compose
  transforms: '$@transforms'

# Dataset over the test samples, applying `preprocessing` to each item.
dataset:
  _target_: Dataset
  data: "@test_dict"
  transform: "@preprocessing"

# ThreadDataLoader generates batches on a separate thread, so data loading
# runs asynchronously with inference.
dataloader:
  _target_: ThreadDataLoader
  dataset: "@dataset"
  batch_size: "@batch_size"
  num_workers: "@num_workers"

# SimpleInferer calls the network directly on each batch; replace it with a
# different inferer type if your network needs another inference scheme.
inferer: {_target_: SimpleInferer}

# Converts engine output into the form the metrics expect:
#   1. softmax over the prediction;
#   2. argmax the prediction, then one-hot encode prediction and label to
#      @num_classes channels;
#   3. convert both to tensors on @device.
postprocessing:
  _target_: Compose
  transforms:
  - _target_: Activationsd
    keys: "@pred"
    softmax: true
  - _target_: AsDiscreted
    keys:
    - "@pred"
    - "@label"
    argmax:
    - true   # prediction: take the argmax
    - false  # label: left as-is (presumably already a class index)
    to_onehot: "@num_classes"
  - _target_: ToTensord
    keys:
    - "@pred"
    - "@label"
    device: "@device"

# Evaluation handlers: load weights into @network from @ckpt_path, then log
# statistics during the run.
# NOTE(review): the loader is disabled when the checkpoint file is missing, so
# evaluation then proceeds silently with the network's initial weights.
val_handlers:
- _target_: CheckpointLoader
  _disabled_: '$not os.path.exists(@ckpt_path)'
  load_path: '@ckpt_path'
  load_dict: {model: '@network'}
- _target_: StatsHandler
  name: null  # null -> log through engine.logger
  output_transform: "$lambda x: None"

# Engine for running inference; ties together the objects defined above and
# declares the metrics.
evaluator:
  _target_: SupervisedEvaluator
  device: '@device'
  val_data_loader: '@dataloader'
  network: '@network'
  inferer: '@inferer'
  postprocessing: '@postprocessing'
  key_val_metric:
    val_accuracy:
      _target_: ignite.metrics.Accuracy
      # `postprocessing` one-hot encodes BOTH pred and label, so ignite gets
      # prediction and target tensors of identical shape for each sample.
      # With the default is_multilabel=false, ignite treats equal-shape inputs
      # as the "binary" case and scores every one-hot element separately,
      # which inflates accuracy (a wrong 4-class prediction still has 2 of 4
      # elements matching). Multilabel mode counts a sample as correct only
      # when the whole one-hot vector matches, i.e. plain classification
      # accuracy.
      is_multilabel: true
      output_transform: $monai.handlers.from_engine([@pred, @label])
  additional_metrics:
    val_f1:  # further metrics can be added alongside this one
      _target_: ConfusionMatrix
      metric_name: 'f1 score'
      output_transform: $monai.handlers.from_engine([@pred, @label])
  val_handlers: '@val_handlers'

# Executed before `run`: enable cuDNN benchmark mode (autotuned convolution
# algorithms).
initialize:
- "$setattr(torch.backends.cudnn, 'benchmark', True)"
# Workflow entry point: run the evaluator. The previous value had been mangled
# into "[email protected]()" by web e-mail obfuscation, which is not
# a valid expression.
run:
- "$@evaluator.run()"