HirraA committed on
Commit
006869b
·
verified ·
1 Parent(s): c44d232

Upload 19 files

Browse files
base/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .base_data_loader import *
2
+ from .base_dataset import *
3
+ from .base_model import *
4
+ from .base_trainer import *
base/base_data_loader.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from torch.utils.data import DataLoader
3
+ from torch.utils.data.dataloader import default_collate
4
+ from torch.utils.data.sampler import SubsetRandomSampler
5
+
6
+
7
class BaseDataLoader(DataLoader):
    """
    Base class for all data loaders.

    Wraps ``torch.utils.data.DataLoader`` and optionally carves a validation
    split out of ``dataset`` using ``SubsetRandomSampler``s.
    """

    def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
        # validation_split: float fraction of the dataset, or int absolute size.
        self.validation_split = validation_split
        self.shuffle = shuffle

        self.batch_idx = 0
        self.n_samples = len(dataset)

        # NOTE: subclasses may override _split_sampler; it runs BEFORE
        # init_kwargs is assembled and may flip self.shuffle to False,
        # so keep this call order intact.
        self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)

        self.init_kwargs = {
            'dataset': dataset,
            'batch_size': batch_size,
            'shuffle': self.shuffle,
            'collate_fn': collate_fn,
            'num_workers': num_workers
        }
        super().__init__(sampler=self.sampler, **self.init_kwargs)

    def _split_sampler(self, split):
        """Return (train_sampler, valid_sampler); (None, None) when split == 0.0."""
        if split == 0.0:
            return None, None

        idx_full = np.arange(self.n_samples)

        # Fixed seed so the train/validation partition is reproducible
        # across runs (this also reseeds numpy's global RNG).
        np.random.seed(0)
        np.random.shuffle(idx_full)

        if isinstance(split, int):
            # An integer split is an absolute validation-set size.
            assert split > 0
            assert split < self.n_samples, "validation set size is configured to be larger than entire dataset."
            len_valid = split
        else:
            len_valid = int(self.n_samples * split)

        valid_idx = idx_full[0:len_valid]
        train_idx = np.delete(idx_full, np.arange(0, len_valid))

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)

        # turn off shuffle option which is mutually exclusive with sampler
        self.shuffle = False
        self.n_samples = len(train_idx)

        return train_sampler, valid_sampler

    def split_validation(self):
        """Return a DataLoader over the validation indices, or None if no split."""
        if self.valid_sampler is None:
            return None
        else:
            return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)
base/base_dataset.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ from pathlib import Path
3
+ from typing import Any, Callable, Optional
4
+ from torchvision.datasets import VisionDataset
5
+
6
+
7
class BaseDataset(VisionDataset):
    """Vision dataset over paired image/mask files laid out as
    ``<root>/<train|test>/images/*.jpg`` and ``<root>/<train|test>/masks/*.png``.
    """

    def __init__(
        self,
        root: str,
        loader: Callable[[str], Any],
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        train: bool = True
    ) -> None:
        super().__init__(root, transforms, transform, target_transform)

        self.root_path = Path(root)
        self.loader = loader

        # glob's root_dir keeps the stored paths relative to root (Python 3.10+).
        split = 'train' if train else 'test'
        self.data = sorted(glob.glob(f'{split}/images/*.jpg', root_dir=root))
        self.masks = sorted(glob.glob(f'{split}/masks/*.png', root_dir=root))

    def __getitem__(self, index: int) -> Any:
        """Return (image, mask) for *index*; the mask is binarised to {0., 1.}."""
        image_rel, mask_rel = self.data[index], self.masks[index]
        image = self.loader(self.root_path / image_rel)
        mask = self.loader(self.root_path / mask_rel)
        image, mask = self.transforms(image, mask)
        # Collapse the channel dim and force a hard 0/1 float mask.
        return image, mask.squeeze(dim=0).bool().float()

    def __len__(self) -> int:
        return len(self.data)
base/base_model.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod
2
+ import numpy as np
3
+ import torch.nn as nn
4
+
5
+
6
class BaseModel(nn.Module):
    """Base class for all models: appends a trainable-parameter count to ``str()``."""

    @abstractmethod
    def forward(self, *inputs):
        """Forward pass logic.

        :return: Model output
        """
        raise NotImplementedError

    def __str__(self):
        """Model prints with number of trainable parameters."""
        trainable = (p for p in self.parameters() if p.requires_grad)
        n_params = sum(np.prod(p.size()) for p in trainable)
        return f"{super().__str__()}\nTrainable parameters: {n_params}"
base/base_trainer.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod
2
+ import torch
3
+ from logger import TensorboardWriter
4
+ from numpy import inf
5
+
6
+
7
class BaseTrainer:
    """
    Base class for all trainers.

    Handles the epoch loop, metric-based model selection, early stopping,
    and checkpoint save/resume. Subclasses implement ``_train_epoch``.
    """

    def __init__(self, model, criterion, metric_ftns, optimizer, config):
        self.config = config
        self.logger = config.get_logger('trainer', config['trainer']['verbosity'])

        self.model = model
        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            # monitor is e.g. "min val_loss" or "max val_f1_score"
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)
            if self.early_stop <= 0:
                # non-positive early_stop disables early stopping
                self.early_stop = inf

        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard'])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)

    @abstractmethod
    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Current epoch number
        """
        raise NotImplementedError

    def train(self):
        """
        Full training logic
        """
        not_improved_count = 0
        for epoch in range(self.start_epoch, self.epochs + 1):
            result = self._train_epoch(epoch)

            # save logged informations into log dict
            log = {'epoch': epoch}
            log.update(result)

            # print logged informations to the screen
            for key, value in log.items():
                self.logger.info(' {:15s}: {}'.format(str(key), value))

            # evaluate model performance according to configured metric, save best checkpoint as model_best
            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not, according to specified metric(mnt_metric)
                    # (ties count as improvement, so the newest equal-scoring model wins)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
                except KeyError:
                    self.logger.warning("Warning: Metric '{}' is not found. "
                                        "Model performance monitoring is disabled.".format(self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if not_improved_count > self.early_stop:
                    self.logger.info("Validation performance didn\'t improve for {} epochs. "
                                     "Training stops.".format(self.early_stop))
                    break

            # checkpoints are only written every save_period epochs, so a
            # "best" epoch between save points is not persisted
            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param save_best: if True, also copy the saved checkpoint to 'model_best.pth'
        """
        arch = type(self.model).__name__
        state = {
            'arch': arch,
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.mnt_best,
            'config': self.config
        }
        filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch))
        torch.save(state, filename)
        self.logger.info("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = str(self.checkpoint_dir / 'model_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth ...")

    def _resume_checkpoint(self, resume_path):
        """
        Resume from saved checkpoints

        :param resume_path: Checkpoint path to be resumed
        """
        resume_path = str(resume_path)
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path)
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']

        # load architecture params from checkpoint.
        if checkpoint['config']['arch'] != self.config['arch']:
            self.logger.warning("Warning: Architecture configuration given in config file is different from that of "
                                "checkpoint. This may yield an exception while state_dict is being loaded.")
        self.model.load_state_dict(checkpoint['state_dict'])

        # load optimizer state from checkpoint only when optimizer type is not changed.
        if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']:
            self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. "
                                "Optimizer parameters not being resumed.")
        else:
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))
data_loader/data_loaders.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from base import BaseDataLoader
2
+ from dataset.datasets import PatchedDataset
3
+ from torchvision import transforms
4
+ from torch.utils.data.sampler import SequentialSampler
5
+
6
+
7
class PatchedDataLoader(BaseDataLoader):
    """Data loader over :class:`PatchedDataset`.

    The train/validation split is first drawn at the *image* level by the base
    class, then translated into *patch* indices so that patches from a
    validation image never leak into training batches.
    """

    def __init__(
        self,
        data_dir,
        patch_size,
        batch_size,
        patch_stride=None,
        preds=None,
        target_dist=0.0,
        shuffle=True,
        validation_split=0.0,
        num_workers=1,
        training=True
    ):
        # NOTE(review): hard-coded channel mean/std — presumably computed over
        # this project's training set; confirm if the data changes.
        trsfm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.3551, 0.4698, 0.2261),
                                 (0.1966, 0.1988, 0.1761))
        ])
        target_trsfm = transforms.Compose([
            transforms.ToTensor(),
        ])
        rand_trsfm = transforms.RandomApply([
            transforms.RandomVerticalFlip(),
            transforms.RandomHorizontalFlip()
        ])
        self.data_dir = data_dir
        # late_init=True: patch creation is deferred until _split_sampler knows
        # which image indices belong to the validation split.
        self.dataset = PatchedDataset(
            self.data_dir,
            patch_size,
            patch_stride=patch_stride,
            preds=preds,
            target_dist=target_dist,
            transform=trsfm,
            target_transform=target_trsfm,
            rand_transform=rand_trsfm if training and shuffle else None,
            train=training,
            late_init=True
        )
        super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers)

    def _split_sampler(self, split):
        """Split at image level, build patches, then remap samplers to patch indices."""
        # Base class returns samplers over IMAGE indices (or (None, None)).
        train_sampler, valid_sampler = super()._split_sampler(split)

        if valid_sampler is not None:
            self.dataset.make_dataset(valid_indices=valid_sampler.indices)
        else:
            self.dataset.make_dataset()

        # Re-bucket every generated patch by whether its source image is in
        # the validation split; patch.idx is the image index.
        train_idx, valid_idx = [], []
        for patch in self.dataset.patches:
            if valid_sampler is not None and patch.idx in valid_sampler.indices:
                valid_idx.append(self.dataset.patches.index(patch))
            else:
                train_idx.append(self.dataset.patches.index(patch))

        if valid_sampler is not None:
            # Reuse the sampler objects but point them at patch indices.
            train_sampler.indices, valid_sampler.indices = train_idx, valid_idx
        else:
            train_sampler = SequentialSampler(train_idx)

        # turn off shuffle option which is mutually exclusive with sampler
        self.shuffle = False
        self.n_samples = len(train_idx)

        return train_sampler, valid_sampler

    def update_dataset(self, preds):
        """Rebuild the patch set from new predictions, updating samplers in place."""
        self.dataset.preds = preds
        self.dataset.patches.clear()
        self.n_samples = len(self.dataset)

        train_sampler, valid_sampler = self._split_sampler(
            self.validation_split)
        # Mutate the existing samplers so the already-constructed DataLoader
        # keeps working with the new indices.
        if valid_sampler is not None:
            self.valid_sampler.indices = valid_sampler.indices
        self.sampler.indices = train_sampler.indices
dataset/datasets.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Callable, Optional
2
+ import torch.nn.functional as F
3
+ from base import BaseDataset
4
+ from utils.util import TransformMultiple, pil_loader
5
+ from dataset.patches import Patches
6
+
7
+
8
class PatchedDataset(BaseDataset):
    """Dataset of fixed-size patches cut from the image/mask pairs of
    :class:`BaseDataset`.

    When ``preds`` is given, each ground-truth mask is unioned with the
    corresponding prediction before patch selection; ``target_dist`` keeps
    only training patches whose mask has more than that fraction of target
    pixels.
    """

    def __init__(
        self,
        root: str,
        patch_size: int,
        patch_stride: Optional[int] = None,
        preds: Optional[list] = None,
        target_dist: float = 0.0,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        rand_transform: Optional[Callable] = None,
        train: bool = True,
        late_init: bool = False
    ) -> None:
        super().__init__(root, pil_loader, transforms, transform, target_transform, train)
        self.patches = Patches(patch_size, patch_stride)
        self.preds = preds
        # target_dist arrives as a fraction of the patch area; store it as an
        # absolute pixel-count threshold.
        self.target_dist = target_dist * patch_size ** 2
        self.rand_transform = TransformMultiple(rand_transform)
        if not late_init:
            # late_init lets the data loader defer patch creation until the
            # validation image indices are known.
            self.make_dataset()

    def make_dataset(self, valid_indices=()):
        """Build the patch index over every source image.

        :param valid_indices: image indices of the validation split; those
            images are tiled without overlap and without the target-density
            filter. (Default was a mutable ``[]`` — replaced with an immutable
            ``()``; only membership is tested, so callers are unaffected.)
        """
        for idx in range(super().__len__()):
            _, mask = super().__getitem__(idx)
            if self.preds is not None:
                mask = self._union_mask(mask, self.preds[idx])

            if idx not in valid_indices:
                # Training images: strided (possibly overlapping) patches,
                # optionally filtered by target density.
                self.patches.create(idx, mask, cond_fn=self._dist_fn
                                    if self.target_dist != 0.0 else None)
            else:
                self.patches.create(idx, mask, no_overlap=True)

    def __getitem__(self, index: int) -> Any:
        """Return an augmented (image_patch, mask_patch) pair for patch *index*."""
        patch = self.patches[index]
        img, mask = super().__getitem__(patch.idx)

        img_patch = self.patches.get_patch(img, patch)
        mask_patch = self.patches.get_patch(mask, patch)
        # Apply the same random flip(s) jointly to image and mask.
        img_patch, mask_patch = self.rand_transform(
            (img_patch, mask_patch.unsqueeze(dim=0)))
        return img_patch, mask_patch.squeeze(dim=0)

    def _union_mask(self, mask, pred):
        """Pixel-wise union (OR for {0,1} masks) of mask and zero-padded pred."""
        pred = F.pad(
            pred, (0, mask.shape[1] - pred.shape[1], 0, mask.shape[0] - pred.shape[0]))
        return (mask + pred) - (mask * pred)

    def _dist_fn(self, mask, patch):
        """Patch filter: keep patches whose nonzero mask pixels exceed the threshold."""
        data = self.patches.get_patch(mask, patch)
        return data.count_nonzero() > self.target_dist
dataset/patches.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import UserList
2
+ from torchvision.utils import make_grid
3
+
4
+
5
class Patch:
    """A patch location: source-image index ``idx`` plus top-left corner ``(x, y)``."""

    def __init__(self, idx, x, y) -> None:
        self.idx = idx
        self.x, self.y = x, y
        # Optional payload attached later (see Patches.store_data).
        self.data = None

    def __eq__(self, other: object) -> bool:
        # Compare by location only; attached data is ignored.
        if not isinstance(other, Patch):
            # Was: attribute access on arbitrary objects -> AttributeError.
            return NotImplemented
        return self.idx == other.idx and self.x == other.x and self.y == other.y

    def __hash__(self) -> int:
        # Defining __eq__ alone sets __hash__ to None; restore hashing so
        # patches can live in sets/dict keys, consistent with __eq__.
        return hash((self.idx, self.x, self.y))

    def __repr__(self) -> str:
        return f'{type(self).__name__}(idx={self.idx}, x={self.x}, y={self.y})'
13
+
14
+
15
class Patches(UserList):
    """List of Patch locations with helpers to cut windows out of tensors and
    stitch per-patch results back into a full-size grid."""

    def __init__(self, size, stride=None):
        super().__init__()
        self.size = size
        # Default stride equals the patch size, i.e. non-overlapping tiling.
        self.stride = stride if stride is not None else size

    def create(self, index, data, cond_fn=None, no_overlap=False):
        """Append patch locations tiling *data* (image *index*).

        cond_fn(data, patch), when given, may veto individual patches.
        Windows that would overrun the bottom/right edge are dropped.
        """
        stride = self.size if no_overlap else self.stride
        for x in range(0, data.size(-2) - self.size + 1, stride):
            for y in range(0, data.size(-1) - self.size + 1, stride):
                patch = Patch(index, x, y)
                if cond_fn is None or cond_fn(data, patch):
                    self.append(patch)

    def get_patch(self, data, patch: Patch):
        """Slice the patch window out of a 2-D (H, W) or 3-D (C, H, W) tensor."""
        assert data.ndim in {2, 3}, 'only 2-D and 3-D Tensors are supported.'
        # Temporarily add a channel dim so one slicing expression covers both.
        _data = data.unsqueeze(dim=0) if data.ndim == 2 else data
        data_patch = _data[:, patch.x:patch.x + self.size,
                           patch.y:patch.y + self.size]
        return data_patch.squeeze(dim=0) if data.ndim == 2 else data_patch

    def store_data(self, indices, data):
        """Attach per-patch outputs; *data* is a list of batched tensors."""
        for idx in range(len(indices)):
            self[indices[idx]].data = [data[i][idx] for i in range(len(data))]

    def retrieve_data(self, indices):
        """Inverse of store_data: regroup the attached outputs by data slot."""
        return [[self[idx].data[i] for idx in indices] for i in range(len(self[indices[0]].data))]

    def combine(self, index: int, data_idx: int):
        """Stitch the stored patches of image *index* back into one tensor."""
        indices = [self.index(patch) for patch in self if patch.idx == index]
        # Row-major order so make_grid lays the patches out correctly.
        indices.sort(key=lambda idx: (self[idx].x, self[idx].y))

        data = self.retrieve_data(indices)
        # Patches with x == 0 form the top row, so their count is the number
        # of patches per grid row.
        nrow = sum([self[idx].x == 0 for idx in indices])
        return make_grid(data[data_idx], nrow, padding=0)
logger/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .logger import *
2
+ from .visualization import *
logger/logger.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import logging.config
3
+ from pathlib import Path
4
+
5
+ from utils import read_json
6
+
7
+
8
def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO):
    """
    Setup logging configuration from a JSON config file.

    File-handler paths in the config are rewritten to live under *save_dir*.
    Falls back to ``logging.basicConfig`` when the config file is missing.
    """
    config_path = Path(log_config)
    if not config_path.is_file():
        print(f"Warning: logging configuration file is not found in {config_path}.")
        logging.basicConfig(level=default_level)
        return

    config = read_json(config_path)
    # modify logging paths based on run config
    for handler in config['handlers'].values():
        if 'filename' in handler:
            handler['filename'] = str(save_dir / handler['filename'])

    logging.config.dictConfig(config)
logger/logger_config.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": 1,
3
+ "disable_existing_loggers": false,
4
+ "formatters": {
5
+ "simple": {
6
+ "format": "%(message)s"
7
+ },
8
+ "datetime": {
9
+ "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
10
+ }
11
+ },
12
+ "handlers": {
13
+ "console": {
14
+ "class": "logging.StreamHandler",
15
+ "level": "DEBUG",
16
+ "formatter": "simple",
17
+ "stream": "ext://sys.stdout"
18
+ },
19
+ "info_file_handler": {
20
+ "class": "logging.handlers.RotatingFileHandler",
21
+ "level": "INFO",
22
+ "formatter": "datetime",
23
+ "filename": "info.log",
24
+ "maxBytes": 10485760,
25
+ "backupCount": 20,
26
+ "encoding": "utf8"
27
+ }
28
+ },
29
+ "root": {
30
+ "level": "INFO",
31
+ "handlers": [
32
+ "console",
33
+ "info_file_handler"
34
+ ]
35
+ }
36
+ }
logger/visualization.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ from datetime import datetime
3
+
4
+
5
class TensorboardWriter():
    """Thin wrapper around a Tensorboard SummaryWriter that tags every add_*
    call with the current step and train/valid mode, and degrades to no-ops
    when no tensorboard backend is installed."""

    def __init__(self, log_dir, logger, enabled):
        self.writer = None
        self.selected_module = ""

        if enabled:
            log_dir = str(log_dir)

            # Retrieve vizualization writer: prefer the torch-bundled one,
            # fall back to tensorboardX.
            succeeded = False
            for module in ["torch.utils.tensorboard", "tensorboardX"]:
                try:
                    self.writer = importlib.import_module(module).SummaryWriter(log_dir)
                    succeeded = True
                    break
                except ImportError:
                    succeeded = False
                self.selected_module = module

            if not succeeded:
                message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \
                    "this machine. Please install TensorboardX with 'pip install tensorboardx', upgrade PyTorch to " \
                    "version >= 1.1 to use 'torch.utils.tensorboard' or turn off the option in the 'config.json' file."
                logger.warning(message)

        self.step = 0
        self.mode = ''
        # per-mode step counters used by next()
        self.step_tracker = {}

        # SummaryWriter methods that get the step/mode wrapper treatment.
        self.tb_writer_ftns = {
            'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio',
            'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'
        }
        # These methods keep their raw tag (no '/train' or '/valid' suffix).
        self.tag_mode_exceptions = {'add_histogram', 'add_embedding'}
        self.timer = datetime.now()

    def set_step(self, step, mode='train'):
        """Set the current global step; also logs steps/sec between calls."""
        self.mode = mode
        self.step = step
        if step == 0:
            self.timer = datetime.now()
        else:
            duration = datetime.now() - self.timer
            self.add_scalar('steps_per_sec', 1 / duration.total_seconds())
            self.timer = datetime.now()

    def next(self, mode='train'):
        """Advance the per-mode step counter by one."""
        step = self.step_tracker[mode] = self.step_tracker.get(mode, 0) + 1
        self.set_step(step, mode=mode)

    def __getattr__(self, name):
        """
        If visualization is configured to use:
            return add_data() methods of tensorboard with additional information (step, tag) added.
        Otherwise:
            return a blank function handle that does nothing
        """
        if name in self.tb_writer_ftns:
            add_data = getattr(self.writer, name, None)

            def wrapper(tag, data, *args, **kwargs):
                if add_data is not None:
                    # add mode(train/valid) tag
                    if name not in self.tag_mode_exceptions:
                        tag = '{}/{}'.format(tag, self.mode)
                    add_data(tag, data, self.step, *args, **kwargs)
            return wrapper
        else:
            # default action for returning methods defined in this class, set_step() for instance.
            # NOTE(review): object has no '__getattr__' attribute, so this
            # lookup itself raises AttributeError and we always fall through
            # to the custom error below — probably meant
            # object.__getattribute__(self, name); confirm intent.
            try:
                attr = object.__getattr__(name)
            except AttributeError:
                raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name))
            return attr
model/loss.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
def focal_loss(inputs, targets, alpha=0.5, gamma=2, reduction='mean'):
    """Focal loss on top of per-pixel cross entropy.

    :param inputs: class logits of shape (N, C, H, W)
    :param targets: per-pixel labels (cast to long for cross entropy; also used
        as the {0, 1} positive-class indicator for the alpha weighting)
    :param alpha: weight for positive pixels; (1 - alpha) weights negatives
    :param gamma: focusing exponent — down-weights well-classified pixels
    :param reduction: 'mean', 'sum', or anything else for elementwise loss
    """
    ce = F.cross_entropy(inputs, targets.long(), reduction='none')
    # Probability the model assigns to the true class.
    prob = torch.exp(-ce)
    modulated = (1 - prob) ** gamma * ce
    # alpha for positive pixels, (1 - alpha) for negative ones.
    alpha_weight = alpha * targets + (1 - alpha) * (1 - targets)
    weighted = alpha_weight * modulated

    if reduction == 'mean':
        return torch.mean(weighted)
    if reduction == 'sum':
        return torch.sum(weighted)
    return weighted
18
+
19
+
20
def dice_loss(inputs, targets, epsilon=1e-7):
    """Soft Dice loss over softmax class probabilities.

    :param inputs: class logits of shape (N, C, H, W)
    :param targets: per-pixel class indices of shape (N, H, W)
    :param epsilon: denominator smoothing to avoid division by zero
    :return: scalar 1 - mean Dice coefficient (over batch and classes)
    """
    probs = F.softmax(inputs, dim=1)
    # (N, H, W) indices -> (N, C, H, W) one-hot planes matching probs' dtype.
    one_hot = torch.nn.functional.one_hot(targets.long(), num_classes=inputs.shape[1])
    one_hot = one_hot.permute(0, 3, 1, 2).float().type(probs.type())
    spatial = (2, 3)
    overlap = 2 * (probs * one_hot).sum(dim=spatial)
    total = probs.sum(dim=spatial) + one_hot.sum(dim=spatial)
    dice = overlap / (total + epsilon)
    return 1 - dice.mean()
model/metric.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from sklearn.metrics import f1_score as f1
3
+ from sklearn.metrics import precision_score, recall_score
4
+
5
+
6
def precision(output, target):
    """Pixel-wise precision of the argmax prediction over a flattened batch."""
    with torch.no_grad():
        predicted = output.argmax(dim=1)
        assert predicted.shape[0] == len(target)
        return precision_score(target.view(-1).cpu(), predicted.view(-1).cpu())
11
+
12
+
13
def recall(output, target):
    """Pixel-wise recall of the argmax prediction over a flattened batch."""
    with torch.no_grad():
        predicted = output.argmax(dim=1)
        assert predicted.shape[0] == len(target)
        return recall_score(target.view(-1).cpu(), predicted.view(-1).cpu())
18
+
19
+
20
def f1_score(output, target):
    """Pixel-wise F1 score of the argmax prediction over a flattened batch."""
    with torch.no_grad():
        predicted = output.argmax(dim=1)
        assert predicted.shape[0] == len(target)
        return f1(target.view(-1).cpu(), predicted.view(-1).cpu())
model/model.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from base import BaseModel
5
+
6
+
7
class UNet(BaseModel):
    """U-Net for semantic segmentation; returns per-pixel log-probabilities."""

    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        # Bilinear upsampling does not reduce channels, so the deepest
        # feature width is halved to compensate.
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Encoder: keep each scale's features for the skip connections.
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        # Decoder: upsample and merge with the matching encoder features.
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return F.log_softmax(logits, dim=1)
38
+
39
+
40
class DoubleConv(BaseModel):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        # Hidden width defaults to the output width.
        mid_channels = mid_channels or out_channels
        layers = [
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)
58
+
59
+
60
class Down(BaseModel):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # NOTE(review): an 8x8 pooling window with stride 2 / padding 3 halves
        # the spatial size like a 2x2 pool but pools over a much wider area —
        # confirm this is intentional rather than a typo for kernel_size=2.
        pool = nn.MaxPool2d(kernel_size=8, stride=2, padding=3)
        self.maxpool_conv = nn.Sequential(pool, DoubleConv(in_channels, out_channels))

    def forward(self, x):
        return self.maxpool_conv(x)
72
+
73
+
74
class Up(BaseModel):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        # x1: decoder features to upsample; x2: encoder skip features.
        x1 = self.up(x1)
        # input is CHW; pad x1 so its spatial size matches x2 before concat
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
98
+
99
+
100
class OutConv(BaseModel):
    """Final 1x1 convolution mapping feature channels to class logits."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
trainer/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .trainer import *
trainer/trainer.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from base import BaseTrainer
4
+ from torchvision.utils import make_grid
5
+ from utils import MetricTracker
6
+
7
+
8
class Trainer(BaseTrainer):
    """
    Trainer class.

    Adds to BaseTrainer: an "adaptive" phase that periodically re-runs
    inference over the whole dataset and rebuilds the patch set from the
    predictions, plus per-batch metric tracking and tensorboard logging.
    """

    def __init__(self, model, criterion, metric_ftns, optimizer, config, device,
                 data_loader, valid_data_loader=None, lr_scheduler=None, len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.device = device
        self.data_loader = data_loader
        # NOTE(review): the len_epoch parameter is accepted but ignored —
        # len(data_loader) always wins; confirm whether that is intended.
        self.len_epoch = len(self.data_loader)
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        # Log roughly sqrt(batch) times per epoch.
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.adaptive_step = config['trainer']['adaptive_step']

        self.train_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        # Every adaptive_step epochs (after the first window), run full-image
        # inference and rebuild the patch dataset from the predictions.
        if epoch > self.adaptive_step and epoch % self.adaptive_step == 1:
            dataset = self.data_loader.inference.dataset
            self.model.eval()
            with torch.no_grad():
                for batch_idx, (data, target) in enumerate(self.data_loader.inference):
                    data, target = data.to(self.device), target.to(self.device)
                    output = self.model(data)

                    # Map batch positions back to global patch indices.
                    batch_size = self.data_loader.inference.batch_size
                    patch_idx = torch.arange(
                        batch_size * batch_idx, batch_size * batch_idx + data.shape[0])
                    pred = torch.argmax(output, dim=1)
                    dataset.patches.store_data(patch_idx, [pred.unsqueeze(1)])

                # Stitch per-patch predictions into one mask per source image.
                preds = [dataset.patches.combine(idx, data_idx=0)[0].cpu()
                         for idx in range(len(dataset.data))]

            self.data_loader.update_dataset(preds)
            self.len_epoch = len(self.data_loader)

        self.model.train()
        self.train_metrics.reset()
        for batch_idx, (data, target) in enumerate(self.data_loader):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            self.train_metrics.update('loss', loss.item())
            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(output, target))

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch,
                    self._progress(batch_idx),
                    loss.item()))
                self.writer.add_image('input', make_grid(
                    data.cpu(), nrow=8, normalize=True))

            if batch_idx == self.len_epoch:
                break

        # Advance the tensorboard step once per epoch and flush averages.
        self.writer.next()
        self.train_metrics.add_scalers()
        log = self.train_metrics.result()

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
                data, target = data.to(self.device), target.to(self.device)

                output = self.model(data)
                loss = self.criterion(output, target)

                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(
                        met.__name__, met(output, target))
                self.writer.add_image('input', make_grid(
                    data.cpu(), nrow=8, normalize=True))

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')

        self.writer.next('valid')
        self.valid_metrics.add_scalers()
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        """Format a '[current/total (pct%)]' progress string for log lines."""
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
utils/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .util import *
utils/util.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from collections import OrderedDict
3
+ from itertools import repeat
4
+ from pathlib import Path
5
+ from typing import Any, Callable, List, Optional, Tuple
6
+ import pandas as pd
7
+ import PIL.Image as Image
8
+ import torch
9
+
10
+
11
def read_json(fname):
    """Read a JSON file into an OrderedDict, preserving key order."""
    path = Path(fname)
    return json.loads(path.read_text(), object_hook=OrderedDict)
15
+
16
+
17
def write_json(content, fname):
    """Serialize *content* to *fname* as indented JSON (keys kept in order)."""
    Path(fname).write_text(json.dumps(content, indent=4, sort_keys=False))
21
+
22
+
23
def pil_loader(fname: str) -> Image.Image:
    """Open an image file with PIL.

    NOTE(review): Image.open is lazy — the file handle stays open until the
    image data is actually read; acceptable here since callers transform the
    image immediately, but confirm if file-handle exhaustion ever shows up.
    """
    return Image.open(fname)
25
+
26
+
27
def prepare_device(n_gpu_use):
    """
    setup GPU device if available. get gpu device indices which are used for DataParallel

    :param n_gpu_use: number of GPUs requested by the configuration
    :return: (torch.device to move tensors to, list of GPU ids for DataParallel)
    """
    n_gpu_available = torch.cuda.device_count()
    if n_gpu_use > 0 and n_gpu_available == 0:
        print("Warning: There\'s no GPU available on this machine,"
              "training will be performed on CPU.")
        n_gpu_use = 0
    if n_gpu_use > n_gpu_available:
        # Clamp the request to what the machine actually has.
        print(f"Warning: The number of GPU\'s configured to use is {n_gpu_use}, but only {n_gpu_available} are "
              "available on this machine.")
        n_gpu_use = n_gpu_available
    device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
    gpu_ids = list(range(n_gpu_use))
    return device, gpu_ids
43
+
44
+
45
class TransformMultiple:
    """Apply one transform jointly to a tuple of tensors by concatenating them
    along dim 0, transforming once, then splitting back to the original sizes.
    With no transform, the input is returned untouched."""

    def __init__(self, transform: Optional[Callable] = None) -> None:
        self.transform = transform

    def __call__(self, data: Any) -> Tuple:
        if self.transform is None:
            return data
        merged = self.transform(torch.cat(data))
        # Split back using each input's original dim-0 length.
        return torch.split(merged, [item.size()[0] for item in data])

    def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
        # Indent continuation lines so they align under the head label.
        lines = repr(transform).splitlines()
        pad = " " * len(head)
        return [f"{head}{lines[0]}"] + [f"{pad}{line}" for line in lines[1:]]

    def __repr__(self) -> str:
        parts = [self.__class__.__name__]
        if self.transform is not None:
            parts += self._format_transform_repr(self.transform, "Transform: ")
        return "\n".join(parts)
65
+
66
+
67
class MetricTracker:
    """Running averages for a fixed set of metric keys, backed by a DataFrame
    with 'total' / 'counts' / 'average' columns; optionally mirrors averages
    to a Tensorboard-style writer."""

    def __init__(self, *keys, writer=None):
        self.writer = writer
        self._data = pd.DataFrame(index=keys, columns=['total', 'counts', 'average'])
        self.reset()

    def reset(self):
        """Zero every tracked statistic (typically once per epoch)."""
        for col in self._data.columns:
            self._data[col].values[:] = 0

    def update(self, key, value, n=1):
        """Accumulate *value* (a batch average over *n* samples) into *key*.

        Uses ``DataFrame.at`` for scalar access instead of the original
        chained ``self._data.total[key]`` indexing, which is fragile under
        pandas copy-on-write semantics.
        """
        self._data.at[key, 'total'] += value * n
        self._data.at[key, 'counts'] += n
        self._data.at[key, 'average'] = self._data.at[key, 'total'] / self._data.at[key, 'counts']

    def add_scalers(self):
        """Push current averages to the writer, if one is attached."""
        if self.writer is not None:
            for key in self._data.index:
                self.writer.add_scalar(key, self._data.at[key, 'average'])

    def avg(self, key):
        """Return the running average for *key*."""
        return self._data.at[key, 'average']

    def result(self):
        """Return {key: running average} for all tracked keys."""
        return dict(self._data.average)