File size: 2,792 Bytes
92f0e98 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import torch
import numpy as np
from pathlib import Path
from typing import List, Optional, Tuple
from monai.transforms import Compose, LoadImageD, LambdaD, AddChannelD, CenterSpatialCropD
class DuplicateKeyD:
    """
    Dictionary transform that copies the value stored under ``from_key``
    into ``to_key``, leaving the original entry untouched.

    The input dictionary is not mutated; a shallow copy is returned.
    """
    def __init__(self, from_key: str, to_key: str):
        # key to read from / key to write the duplicate to
        self.from_key = from_key
        self.to_key = to_key
    def __call__(self, data):
        result = dict(data)
        result[self.to_key] = result[self.from_key]
        return result
class MergeKeysD:
    """
    Dictionary transform that concatenates the arrays stored under ``keys``
    along axis/dim 0 (the channel axis after an add-channel transform) and
    stores the result under ``to_key``.

    The source entries are removed from the returned dictionary; the input
    dictionary itself is not mutated. Supports both ``torch.Tensor`` and
    numpy-array values (the type of the first entry decides which backend
    performs the concatenation).
    """
    def __init__(self, keys: List[str], to_key: str):
        # keys to merge (in order) / destination key for the merged array
        self.keys = keys
        self.to_key = to_key
    def __call__(self, data):
        result = dict(data)
        # Degenerate case: a single key that is already the destination —
        # there is nothing to concatenate or rename.
        if len(self.keys) == 1 and self.keys[0] == self.to_key:
            return result
        parts = [result.pop(key) for key in self.keys]
        if isinstance(parts[0], torch.Tensor):
            merged = torch.cat(parts, dim=0)
        else:
            merged = np.concatenate(parts, axis=0)
        result[self.to_key] = merged
        return result
def get_image_loading_transform(hparams, image_dir: Path, mask_dir: Optional[Path] = None) -> Tuple[Compose, List[str]]:
    """
    Builds the transform that loads an image and, when the configuration
    requires one, its corresponding mask.

    The 'image' entry initially holds a relative path; it is resolved
    against ``image_dir`` (and against ``mask_dir`` for the 'mask' entry)
    before loading.

    :param hparams: configuration object; ``hparams.mask == 'none'`` means
        no mask is loaded.
    :param image_dir: directory the relative image paths are resolved against.
    :param mask_dir: directory the relative mask paths are resolved against;
        required whenever a mask is loaded.
    :return: the composed loading transform and the list of keys it produces
        (``['image']`` or ``['image', 'mask']``).
    """
    if hparams.mask != 'none':
        assert mask_dir is not None
        transform = Compose([
            # the mask shares the image's relative path, so duplicate the
            # key before resolving each against its own base directory
            DuplicateKeyD(from_key='image', to_key='mask'),
            LambdaD(keys='image', func=lambda p: image_dir / p),
            LoadImageD(keys='image'),
            LambdaD(keys='mask', func=lambda p: mask_dir / p),
            LoadImageD(keys='mask'),
        ])
        return transform, ['image', 'mask']
    transform = Compose([
        LambdaD(keys='image', func=lambda p: image_dir / p),
        LoadImageD(keys='image'),
    ])
    return transform, ['image']
def get_apply_crop_transform(hparams, loaded_keys: List[str]) -> Tuple[Compose, List[str]]:
    """
    Builds the transform that adds a channel dimension and center-crops the
    loaded keys to ``hparams.input_size`` along each of ``hparams.input_dim``
    spatial dimensions.

    When ``hparams.mask == 'crop'`` an empty transform is returned instead —
    NOTE(review): presumably the crop is then derived from the mask elsewhere
    in the pipeline; confirm against the callers.

    :param hparams: configuration object providing ``mask``, ``input_size``
        and ``input_dim``.
    :param loaded_keys: dictionary keys the transform operates on.
    :return: the composed transform and the (unchanged) list of keys.
    """
    if hparams.mask == 'crop':
        # no center crop in this configuration
        return Compose([]), loaded_keys
    roi = [hparams.input_size] * hparams.input_dim
    crop = Compose([
        AddChannelD(keys=loaded_keys),
        CenterSpatialCropD(keys=loaded_keys, roi_size=roi),
    ])
    return crop, loaded_keys
def get_stacking_transform(hparams, loaded_keys: List[str]) -> Tuple[Compose, List[str]]:
    """
    Builds the transform that merges multiple loaded keys (e.g. image and
    mask) as channels into the single 'image' key.

    With a single loaded key the identity transform is returned and the key
    list is passed through unchanged.

    :param hparams: configuration object (unused here; kept for a uniform
        factory signature).
    :param loaded_keys: keys currently present in the data dictionary.
    :return: the composed transform and the resulting list of keys.
    """
    if len(loaded_keys) <= 1:
        # nothing to merge
        return Compose([]), loaded_keys
    return Compose([MergeKeysD(loaded_keys, 'image')]), ['image']