Upload 19 files
Browse files- base/__init__.py +4 -0
- base/base_data_loader.py +62 -0
- base/base_dataset.py +35 -0
- base/base_model.py +25 -0
- base/base_trainer.py +152 -0
- data_loader/data_loaders.py +83 -0
- dataset/datasets.py +60 -0
- dataset/patches.py +49 -0
- logger/__init__.py +2 -0
- logger/logger.py +23 -0
- logger/logger_config.json +36 -0
- logger/visualization.py +78 -0
- model/loss.py +28 -0
- model/metric.py +24 -0
- model/model.py +106 -0
- trainer/__init__.py +1 -0
- trainer/trainer.py +135 -0
- utils/__init__.py +1 -0
- utils/util.py +91 -0
base/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base_data_loader import *
|
2 |
+
from .base_dataset import *
|
3 |
+
from .base_model import *
|
4 |
+
from .base_trainer import *
|
base/base_data_loader.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from torch.utils.data import DataLoader
|
3 |
+
from torch.utils.data.dataloader import default_collate
|
4 |
+
from torch.utils.data.sampler import SubsetRandomSampler
|
5 |
+
|
6 |
+
|
7 |
+
class BaseDataLoader(DataLoader):
    """
    Base class for all data loaders.

    Thin wrapper around ``torch.utils.data.DataLoader`` that can hold out part
    of the dataset for validation. The training loader is ``self`` itself; the
    matching validation loader is obtained from :meth:`split_validation`.

    :param dataset: dataset to draw samples from (must support ``len()``)
    :param batch_size: batch size forwarded to ``DataLoader``
    :param shuffle: shuffle flag; forced to ``False`` when a validation split
        is made, because ``DataLoader`` does not accept ``shuffle`` together
        with an explicit sampler
    :param validation_split: ``0.0`` for no split, an ``int`` for an absolute
        number of validation samples, or a float fraction of the dataset
    :param num_workers: worker count forwarded to ``DataLoader``
    :param collate_fn: collate function forwarded to ``DataLoader``
    """

    def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
        self.validation_split = validation_split
        self.shuffle = shuffle

        self.batch_idx = 0
        self.n_samples = len(dataset)

        # May flip self.shuffle to False and shrink self.n_samples, so it has
        # to run before init_kwargs is assembled below.
        self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)

        self.init_kwargs = {
            'dataset': dataset,
            'batch_size': batch_size,
            'shuffle': self.shuffle,
            'collate_fn': collate_fn,
            'num_workers': num_workers
        }
        super().__init__(sampler=self.sampler, **self.init_kwargs)

    def _split_sampler(self, split):
        """
        Build the train/validation samplers for the requested split.

        :param split: ``0.0`` (no split), an ``int`` sample count, or a float fraction
        :return: ``(train_sampler, valid_sampler)``; both ``None`` when ``split == 0.0``
        """
        if split == 0.0:
            return None, None

        idx_full = np.arange(self.n_samples)

        # Use a private generator seeded with 0 instead of np.random.seed(0):
        # it produces exactly the same permutation, but does not reset the
        # process-wide NumPy RNG every time a split is (re)computed.
        np.random.RandomState(0).shuffle(idx_full)

        if isinstance(split, int):
            assert split > 0
            assert split < self.n_samples, "validation set size is configured to be larger than entire dataset."
            len_valid = split
        else:
            len_valid = int(self.n_samples * split)

        # The first len_valid shuffled indices are held out for validation.
        valid_idx = idx_full[:len_valid]
        train_idx = idx_full[len_valid:]

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)

        # turn off shuffle option which is mutually exclusive with sampler
        self.shuffle = False
        self.n_samples = len(train_idx)

        return train_sampler, valid_sampler

    def split_validation(self):
        """
        Return a ``DataLoader`` over the held-out validation samples.

        :return: validation loader built with the same kwargs as the training
            loader, or ``None`` when no validation split was configured
        """
        if self.valid_sampler is None:
            return None
        return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)
|
base/base_dataset.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import Any, Callable, Optional
|
4 |
+
from torchvision.datasets import VisionDataset
|
5 |
+
|
6 |
+
|
7 |
+
class BaseDataset(VisionDataset):
    """
    Image/mask dataset read from ``<root>/<train|test>/images`` and ``.../masks``.

    Images are the ``*.jpg`` files and masks the ``*.png`` files of the chosen
    split; both lists are sorted and paired by position.

    :param root: dataset root directory
    :param loader: callable that turns a file path into a loaded sample
    :param transforms: joint transform forwarded to ``VisionDataset``
    :param transform: image transform forwarded to ``VisionDataset``
    :param target_transform: mask transform forwarded to ``VisionDataset``
    :param train: read the ``train`` sub-directory when true, ``test`` otherwise
    """

    def __init__(
        self,
        root: str,
        loader: Callable[[str], Any],
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        train: bool = True
    ) -> None:
        super().__init__(root, transforms, transform, target_transform)

        self.root_path = Path(root)
        self.loader = loader

        subset = 'train' if train else 'test'
        image_pattern = f'{subset}/images/*.jpg'
        mask_pattern = f'{subset}/masks/*.png'
        # Paths are kept relative to root; __getitem__ joins them with root_path.
        self.data = sorted(glob.glob(image_pattern, root_dir=root))
        self.masks = sorted(glob.glob(mask_pattern, root_dir=root))

    def __getitem__(self, index: int) -> Any:
        """
        Load and transform the image/mask pair at ``index``.

        :return: ``(img, mask)`` where the mask has dim 0 squeezed and every
            non-zero value mapped to ``1.0``
        """
        img_rel, mask_rel = self.data[index], self.masks[index]
        img_file = self.root_path / img_rel
        mask_file = self.root_path / mask_rel

        img = self.loader(img_file)
        mask = self.loader(mask_file)
        img, mask = self.transforms(img, mask)

        binary_mask = mask.squeeze(dim=0).bool().float()
        return img, binary_mask

    def __len__(self) -> int:
        """Return the number of images found for the selected split."""
        return len(self.data)
|
base/base_model.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractmethod
|
2 |
+
import numpy as np
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
|
6 |
+
class BaseModel(nn.Module):
    """
    Base class for all models
    """
    @abstractmethod
    def forward(self, *inputs):
        """
        Forward pass logic

        :return: Model output
        """
        raise NotImplementedError

    def __str__(self):
        """
        Model prints with number of trainable parameters

        :return: the standard module description followed by a
            ``Trainable parameters: N`` line
        """
        # numel() always yields an exact int; np.prod(p.size()) returns the
        # float 1.0 for 0-dim parameters and would turn the total into a float.
        params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return super().__str__() + '\nTrainable parameters: {}'.format(params)
|
base/base_trainer.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractmethod
|
2 |
+
import torch
|
3 |
+
from logger import TensorboardWriter
|
4 |
+
from numpy import inf
|
5 |
+
|
6 |
+
|
7 |
+
class BaseTrainer:
    """
    Base class for all trainers.

    Drives the epoch loop, tracks the monitored metric for best-model saving
    and early stopping, and saves/resumes checkpoints. Subclasses provide
    ``_train_epoch``.
    """

    def __init__(self, model, criterion, metric_ftns, optimizer, config):
        self.config = config
        trainer_cfg = config['trainer']
        self.logger = config.get_logger('trainer', trainer_cfg['verbosity'])

        self.model = model
        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer

        self.epochs = trainer_cfg['epochs']
        self.save_period = trainer_cfg['save_period']
        self.monitor = trainer_cfg.get('monitor', 'off')

        # "monitor" is either 'off' or '<min|max> <metric name>'
        if self.monitor != 'off':
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            patience = trainer_cfg.get('early_stop', inf)
            # a non-positive patience disables early stopping
            self.early_stop = inf if patience <= 0 else patience
        else:
            self.mnt_mode = 'off'
            self.mnt_best = 0

        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir

        # visualization writer
        self.writer = TensorboardWriter(config.log_dir, self.logger, trainer_cfg['tensorboard'])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)

    @abstractmethod
    def _train_epoch(self, epoch):
        """
        Run one training epoch.

        :param epoch: Current epoch number
        :return: values to log for this epoch (merged into the epoch log)
        """
        raise NotImplementedError

    def train(self):
        """
        Run the full training loop from ``start_epoch`` to ``epochs``.
        """
        epochs_without_gain = 0
        for epoch in range(self.start_epoch, self.epochs + 1):
            result = self._train_epoch(epoch)

            # gather everything reported for this epoch
            log = {'epoch': epoch}
            log.update(result)

            # echo the epoch summary
            for key, value in log.items():
                self.logger.info(' {:15s}: {}'.format(str(key), value))

            # compare against the best value seen so far for the monitored metric
            best = False
            if self.mnt_mode != 'off':
                try:
                    if self.mnt_mode == 'min':
                        improved = log[self.mnt_metric] <= self.mnt_best
                    elif self.mnt_mode == 'max':
                        improved = log[self.mnt_metric] >= self.mnt_best
                    else:
                        improved = False
                except KeyError:
                    self.logger.warning("Warning: Metric '{}' is not found. "
                                        "Model performance monitoring is disabled.".format(self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    epochs_without_gain = 0
                    best = True
                else:
                    epochs_without_gain += 1

                if epochs_without_gain > self.early_stop:
                    self.logger.info("Validation performance didn\'t improve for {} epochs. "
                                     "Training stops.".format(self.early_stop))
                    break

            # checkpoints (and therefore model_best) are only written on save epochs
            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Write a checkpoint for ``epoch`` into ``checkpoint_dir``.

        :param epoch: current epoch number
        :param save_best: if True, also write the same state to 'model_best.pth'
        """
        state = {
            'arch': type(self.model).__name__,
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.mnt_best,
            'config': self.config
        }
        filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch))
        torch.save(state, filename)
        self.logger.info("Saving checkpoint: {} ...".format(filename))
        if not save_best:
            return
        best_path = str(self.checkpoint_dir / 'model_best.pth')
        torch.save(state, best_path)
        self.logger.info("Saving current best: model_best.pth ...")

    def _resume_checkpoint(self, resume_path):
        """
        Restore model, optimizer and monitoring state from a checkpoint.

        :param resume_path: Checkpoint path to be resumed
        """
        resume_path = str(resume_path)
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path)
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']
        saved_config = checkpoint['config']

        # model weights are loaded even if the architecture config differs (warn only)
        if saved_config['arch'] != self.config['arch']:
            self.logger.warning("Warning: Architecture configuration given in config file is different from that of "
                                "checkpoint. This may yield an exception while state_dict is being loaded.")
        self.model.load_state_dict(checkpoint['state_dict'])

        # optimizer state is restored only when the optimizer type is unchanged
        if saved_config['optimizer']['type'] == self.config['optimizer']['type']:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. "
                                "Optimizer parameters not being resumed.")

        self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))
|
data_loader/data_loaders.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from base import BaseDataLoader
|
2 |
+
from dataset.datasets import PatchedDataset
|
3 |
+
from torchvision import transforms
|
4 |
+
from torch.utils.data.sampler import SequentialSampler
|
5 |
+
|
6 |
+
|
7 |
+
class PatchedDataLoader(BaseDataLoader):
    """
    Data loader that serves fixed-size patches cut from a ``PatchedDataset``.

    The train/validation split is first made at image level by the base
    class and then translated into patch positions, so that every patch of a
    validation image ends up in the validation sampler.

    :param data_dir: dataset root handed to ``PatchedDataset``
    :param patch_size: patch edge length handed to ``PatchedDataset``
    :param batch_size: batch size
    :param patch_stride: patch stride handed to ``PatchedDataset``
    :param preds: optional per-image predictions handed to ``PatchedDataset``
    :param target_dist: patch filter threshold handed to ``PatchedDataset``
    :param shuffle: together with ``training`` decides whether random flips are applied
    :param validation_split: see ``BaseDataLoader``
    :param num_workers: worker count
    :param training: selects the train or test part of the dataset
    """

    def __init__(
        self,
        data_dir,
        patch_size,
        batch_size,
        patch_stride=None,
        preds=None,
        target_dist=0.0,
        shuffle=True,
        validation_split=0.0,
        num_workers=1,
        training=True
    ):
        trsfm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.3551, 0.4698, 0.2261),
                                 (0.1966, 0.1988, 0.1761))
        ])
        target_trsfm = transforms.Compose([
            transforms.ToTensor(),
        ])
        rand_trsfm = transforms.RandomApply([
            transforms.RandomVerticalFlip(),
            transforms.RandomHorizontalFlip()
        ])
        self.data_dir = data_dir
        self.dataset = PatchedDataset(
            self.data_dir,
            patch_size,
            patch_stride=patch_stride,
            preds=preds,
            target_dist=target_dist,
            transform=trsfm,
            target_transform=target_trsfm,
            rand_transform=rand_trsfm if training and shuffle else None,
            train=training,
            # patches are created in _split_sampler, once the split is known
            late_init=True
        )
        super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers)

    def _split_sampler(self, split):
        """
        Create the patches and return samplers over patch positions.

        :param split: validation split, see ``BaseDataLoader._split_sampler``
        :return: ``(train_sampler, valid_sampler)``; without a split the train
            sampler is a ``SequentialSampler`` and the valid sampler is ``None``
        """
        train_sampler, valid_sampler = super()._split_sampler(split)

        if valid_sampler is not None:
            self.dataset.make_dataset(valid_indices=valid_sampler.indices)
            # Build the lookup once; membership tests against the raw index
            # array would rescan it for every patch.
            valid_images = {int(i) for i in valid_sampler.indices}
        else:
            self.dataset.make_dataset()
            valid_images = set()

        # enumerate() gives each patch's position directly instead of
        # searching the list again with patches.index(patch) for every patch.
        train_idx, valid_idx = [], []
        for position, patch in enumerate(self.dataset.patches):
            if patch.idx in valid_images:
                valid_idx.append(position)
            else:
                train_idx.append(position)

        if valid_sampler is not None:
            train_sampler.indices, valid_sampler.indices = train_idx, valid_idx
        else:
            # NOTE(review): without a validation split the patches are served
            # in order even when shuffle=True — confirm this is intended.
            train_sampler = SequentialSampler(train_idx)

        # turn off shuffle option which is mutually exclusive with sampler
        self.shuffle = False
        self.n_samples = len(train_idx)

        return train_sampler, valid_sampler

    def update_dataset(self, preds):
        """
        Rebuild the patches with new predictions and refresh the samplers in place.

        The sampler objects already held by the loader are mutated instead of
        replaced, so the running loader picks up the new patch positions.

        :param preds: new per-image predictions stored on the dataset
        """
        self.dataset.preds = preds
        self.dataset.patches.clear()
        self.n_samples = len(self.dataset)

        train_sampler, valid_sampler = self._split_sampler(
            self.validation_split)
        if valid_sampler is not None:
            self.valid_sampler.indices = valid_sampler.indices

        if isinstance(train_sampler, SequentialSampler):
            # SequentialSampler has no ``indices``; it iterates over
            # range(len(data_source)), so that is what must be refreshed.
            self.sampler.data_source = train_sampler.data_source
        else:
            self.sampler.indices = train_sampler.indices
|
dataset/datasets.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Callable, Optional
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from base import BaseDataset
|
4 |
+
from utils.util import TransformMultiple, pil_loader
|
5 |
+
from dataset.patches import Patches
|
6 |
+
|
7 |
+
|
8 |
+
class PatchedDataset(BaseDataset):
    """
    Dataset that serves square patches cut from the images of a ``BaseDataset``.

    :param root: dataset root directory
    :param patch_size: edge length of a patch
    :param patch_stride: step between patches (``Patches`` falls back to
        ``patch_size`` when ``None``)
    :param preds: optional per-image predictions that are merged into the
        masks before patches are selected
    :param target_dist: fraction of a patch's area; when non-zero, only
        patches whose mask has more non-zero pixels than this are kept
    :param transforms: joint transform forwarded to ``BaseDataset``
    :param transform: image transform forwarded to ``BaseDataset``
    :param target_transform: mask transform forwarded to ``BaseDataset``
    :param rand_transform: transform applied to image and mask patch together
    :param train: use the train split when true, the test split otherwise
    :param late_init: when true, patches are not created here and the caller
        must call :meth:`make_dataset`
    """

    def __init__(
        self,
        root: str,
        patch_size: int,
        patch_stride: Optional[int] = None,
        preds: Optional[list] = None,
        target_dist: float = 0.0,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        rand_transform: Optional[Callable] = None,
        train: bool = True,
        late_init: bool = False
    ) -> None:
        super().__init__(root, pil_loader, transforms, transform, target_transform, train)
        self.patches = Patches(patch_size, patch_stride)
        self.preds = preds
        # stored as an absolute pixel count: fraction * patch area
        self.target_dist = target_dist * patch_size ** 2
        self.rand_transform = TransformMultiple(rand_transform)
        if not late_init:
            self.make_dataset()

    def make_dataset(self, valid_indices=()):
        """
        Create the patches for every image and append them to ``self.patches``.

        Existing patches are not cleared first.

        :param valid_indices: image indices used for validation; their patches
            are created without overlap and without the ``target_dist`` filter
        """
        # The filter does not change during the loop, so pick it once.
        cond_fn = self._dist_fn if self.target_dist != 0.0 else None

        for idx in range(super().__len__()):
            _, mask = super().__getitem__(idx)
            if self.preds is not None:
                mask = self._union_mask(mask, self.preds[idx])

            if idx in valid_indices:
                self.patches.create(idx, mask, no_overlap=True)
            else:
                self.patches.create(idx, mask, cond_fn=cond_fn)

    def __getitem__(self, index: int) -> Any:
        """
        Return the image patch and mask patch for patch number ``index``.

        :param index: position in ``self.patches`` (not an image index)
        """
        patch = self.patches[index]
        img, mask = super().__getitem__(patch.idx)

        img_patch = self.patches.get_patch(img, patch)
        mask_patch = self.patches.get_patch(mask, patch)
        # The mask gets a leading dim so both go through rand_transform
        # together; it is removed again before returning.
        img_patch, mask_patch = self.rand_transform(
            (img_patch, mask_patch.unsqueeze(dim=0)))
        return img_patch, mask_patch.squeeze(dim=0)

    def _union_mask(self, mask, pred):
        """
        Combine ``mask`` and ``pred`` as ``mask + pred - mask * pred``.

        ``pred`` is first zero-padded at the end of its last two dimensions
        to the size of ``mask``.
        """
        pred = F.pad(
            pred, (0, mask.shape[1] - pred.shape[1], 0, mask.shape[0] - pred.shape[0]))
        return (mask + pred) - (mask * pred)

    def _dist_fn(self, mask, patch):
        """Return whether the patch of ``mask`` has more than ``target_dist`` non-zero pixels."""
        data = self.patches.get_patch(mask, patch)
        return data.count_nonzero() > self.target_dist
|
dataset/patches.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import UserList
|
2 |
+
from torchvision.utils import make_grid
|
3 |
+
|
4 |
+
|
5 |
+
class Patch():
    """
    Location of one patch inside an image, plus an optional data slot.

    :param idx: index of the image the patch belongs to
    :param x: start offset along the second-to-last dimension
    :param y: start offset along the last dimension
    """

    def __init__(self, idx, x, y) -> None:
        self.idx = idx
        self.x, self.y = x, y
        # filled in later (see Patches.store_data)
        self.data = None

    def __eq__(self, __o: object) -> bool:
        """Two patches are equal when image index and both offsets match; ``data`` is ignored."""
        if not isinstance(__o, Patch):
            # Let Python fall back to its default comparison instead of
            # failing with AttributeError on unrelated objects.
            return NotImplemented
        return self.idx == __o.idx and self.x == __o.x and self.y == __o.y
|
13 |
+
|
14 |
+
|
15 |
+
class Patches(UserList):
    """
    List of ``Patch`` objects with helpers to create, crop and reassemble them.

    :param size: edge length of a (square) patch
    :param stride: step between neighbouring patches; defaults to ``size``
    """

    def __init__(self, size, stride=None):
        super().__init__()
        self.size = size
        self.stride = stride if stride is not None else size

    def create(self, index, data, cond_fn=None, no_overlap=False):
        """
        Append the patches of one image.

        :param index: image index stored on every created patch
        :param data: tensor whose last two dimensions define the patch grid
        :param cond_fn: optional ``cond_fn(data, patch)`` filter; a patch is
            kept only when it returns a truthy value
        :param no_overlap: step by ``size`` instead of ``stride``
        """
        stride = self.size if no_overlap else self.stride
        # Only windows that fit completely are produced.
        for x in range(0, data.size(-2) - self.size + 1, stride):
            for y in range(0, data.size(-1) - self.size + 1, stride):
                patch = Patch(index, x, y)
                if cond_fn is None or cond_fn(data, patch):
                    self.append(patch)

    def get_patch(self, data, patch: "Patch"):
        """
        Crop the ``size`` x ``size`` window described by ``patch`` out of ``data``.

        :param data: 2-D or 3-D tensor; the crop is taken on the last two dimensions
        :return: crop with the same number of dimensions as ``data``
        """
        assert data.ndim in {2, 3}, 'only 2-D and 3-D Tensors are supported.'
        rows = slice(patch.x, patch.x + self.size)
        cols = slice(patch.y, patch.y + self.size)
        if data.ndim == 2:
            return data[rows, cols]
        return data[:, rows, cols]

    def store_data(self, indices, data):
        """
        Attach per-patch data to the patches at ``indices``.

        :param indices: patch positions in this list
        :param data: sequence of sequences; for the k-th entry of ``indices``
            the k-th element of every inner sequence is stored
        """
        for position, patch_index in enumerate(indices):
            self[patch_index].data = [stream[position] for stream in data]

    def retrieve_data(self, indices):
        """
        Inverse of :meth:`store_data`: gather the stored data of several patches.

        :param indices: non-empty sequence of patch positions
        :return: one list per stored data slot, each ordered like ``indices``
        """
        n_slots = len(self[indices[0]].data)
        return [[self[idx].data[slot] for idx in indices] for slot in range(n_slots)]

    def combine(self, index: int, data_idx: int):
        """
        Tile the stored data of all patches of one image into a single grid.

        :param index: image index whose patches are combined
        :param data_idx: which stored data slot to combine
        :return: result of ``make_grid`` with zero padding
        """
        # enumerate() yields each position directly; looking every patch up
        # again with self.index(patch) would make this quadratic.
        indices = [pos for pos, patch in enumerate(self) if patch.idx == index]
        indices.sort(key=lambda idx: (self[idx].x, self[idx].y))

        data = self.retrieve_data(indices)
        # patches with x == 0 form the first row, which gives the grid width
        nrow = sum(self[idx].x == 0 for idx in indices)
        return make_grid(data[data_idx], nrow, padding=0)
|
logger/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .logger import *
|
2 |
+
from .visualization import *
|
logger/logger.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import logging.config
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
from utils import read_json
|
6 |
+
|
7 |
+
|
8 |
+
def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO):
    """
    Configure logging from a JSON config file, falling back to ``basicConfig``.

    :param save_dir: directory (path object) that every handler ``filename`` is placed under
    :param log_config: path to the JSON logging configuration
    :param default_level: level used for ``basicConfig`` when the config file is missing
    """
    config_path = Path(log_config)

    if not config_path.is_file():
        print("Warning: logging configuration file is not found in {}.".format(config_path))
        logging.basicConfig(level=default_level)
        return

    config = read_json(config_path)
    # redirect file handlers into the run's save directory
    for handler in config['handlers'].values():
        if 'filename' in handler:
            handler['filename'] = str(save_dir / handler['filename'])

    logging.config.dictConfig(config)
|
logger/logger_config.json
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"version": 1,
|
3 |
+
"disable_existing_loggers": false,
|
4 |
+
"formatters": {
|
5 |
+
"simple": {
|
6 |
+
"format": "%(message)s"
|
7 |
+
},
|
8 |
+
"datetime": {
|
9 |
+
"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
10 |
+
}
|
11 |
+
},
|
12 |
+
"handlers": {
|
13 |
+
"console": {
|
14 |
+
"class": "logging.StreamHandler",
|
15 |
+
"level": "DEBUG",
|
16 |
+
"formatter": "simple",
|
17 |
+
"stream": "ext://sys.stdout"
|
18 |
+
},
|
19 |
+
"info_file_handler": {
|
20 |
+
"class": "logging.handlers.RotatingFileHandler",
|
21 |
+
"level": "INFO",
|
22 |
+
"formatter": "datetime",
|
23 |
+
"filename": "info.log",
|
24 |
+
"maxBytes": 10485760,
|
25 |
+
"backupCount": 20,
|
26 |
+
"encoding": "utf8"
|
27 |
+
}
|
28 |
+
},
|
29 |
+
"root": {
|
30 |
+
"level": "INFO",
|
31 |
+
"handlers": [
|
32 |
+
"console",
|
33 |
+
"info_file_handler"
|
34 |
+
]
|
35 |
+
}
|
36 |
+
}
|
logger/visualization.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
from datetime import datetime
|
3 |
+
|
4 |
+
|
5 |
+
class TensorboardWriter():
    """
    Optional TensorBoard writer that adds the current step and mode to every call.

    ``add_*`` methods are resolved dynamically in ``__getattr__``: they
    forward to the underlying ``SummaryWriter`` when one is available and are
    silent no-ops otherwise.

    :param log_dir: directory passed to ``SummaryWriter``
    :param logger: logger used to warn when no TensorBoard backend can be imported
    :param enabled: when false, no writer is created
    """

    def __init__(self, log_dir, logger, enabled):
        self.writer = None
        self.selected_module = ""

        if enabled:
            log_dir = str(log_dir)

            # Retrieve visualization writer.
            succeeded = False
            for module in ["torch.utils.tensorboard", "tensorboardX"]:
                try:
                    self.writer = importlib.import_module(module).SummaryWriter(log_dir)
                    succeeded = True
                    break
                except ImportError:
                    succeeded = False
                # only reached when the import above failed
                self.selected_module = module

            if not succeeded:
                message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \
                    "this machine. Please install TensorboardX with 'pip install tensorboardx', upgrade PyTorch to " \
                    "version >= 1.1 to use 'torch.utils.tensorboard' or turn off the option in the 'config.json' file."
                logger.warning(message)

        self.step = 0
        self.mode = ''
        self.step_tracker = {}

        self.tb_writer_ftns = {
            'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio',
            'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'
        }
        # these keep their tag unchanged (no '/<mode>' suffix)
        self.tag_mode_exceptions = {'add_histogram', 'add_embedding'}
        self.timer = datetime.now()

    def set_step(self, step, mode='train'):
        """
        Set the step and mode used by subsequent ``add_*`` calls.

        For every non-zero step the rate since the previous call is logged as
        ``steps_per_sec``.
        """
        self.mode = mode
        self.step = step
        if step == 0:
            self.timer = datetime.now()
        else:
            elapsed = (datetime.now() - self.timer).total_seconds()
            # Two calls can land on the same clock tick; skip the rate then
            # instead of dividing by zero.
            if elapsed > 0:
                self.add_scalar('steps_per_sec', 1 / elapsed)
            self.timer = datetime.now()

    def next(self, mode='train'):
        """Advance the per-mode step counter by one and make it the current step."""
        step = self.step_tracker[mode] = self.step_tracker.get(mode, 0) + 1
        self.set_step(step, mode=mode)

    def __getattr__(self, name):
        """
        If visualization is configured to use:
            return add_data() methods of tensorboard with additional information (step, tag) added.
        Otherwise:
            return a blank function handle that does nothing

        :raises AttributeError: for any name that is not a supported ``add_*`` method
        """
        # __getattr__ only runs after normal lookup failed. Instance state is
        # read through __dict__ here so that an instance created without
        # __init__ (e.g. by copy/pickle) raises AttributeError instead of
        # recursing forever.
        state = self.__dict__
        if name in state.get('tb_writer_ftns', ()):
            add_data = getattr(state.get('writer'), name, None)

            def wrapper(tag, data, *args, **kwargs):
                if add_data is not None:
                    # add mode(train/valid) tag
                    if name not in self.tag_mode_exceptions:
                        tag = '{}/{}'.format(tag, self.mode)
                    add_data(tag, data, self.step, *args, **kwargs)
            return wrapper

        raise AttributeError(
            "type object '{}' has no attribute '{}'".format(state.get('selected_module', ''), name))
|
model/loss.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
|
5 |
+
def focal_loss(inputs, targets, alpha=0.5, gamma=2, reduction='mean'):
    """
    Focal loss built on top of per-element cross entropy.

    :param inputs: class scores accepted by ``F.cross_entropy`` (class dimension 1)
    :param targets: class indices; also used numerically for the alpha weighting
        (presumably binary 0/1 labels - verify against callers)
    :param alpha: weight applied where ``targets`` is 1; ``1 - alpha`` where it is 0
    :param gamma: focusing exponent applied to ``1 - p``
    :param reduction: ``'mean'``, ``'sum'``, or anything else for the unreduced loss
    :return: reduced scalar, or the element-wise loss tensor
    """
    cross_entropy = F.cross_entropy(inputs, targets.long(), reduction='none')
    # cross entropy is -log(p), so exp(-ce) recovers p for the target class
    prob_target = torch.exp(-cross_entropy)
    modulated = (1 - prob_target) ** gamma * cross_entropy

    class_weight = alpha * targets + (1 - alpha) * (1 - targets)
    loss = class_weight * modulated

    if reduction == 'sum':
        return torch.sum(loss)
    if reduction == 'mean':
        return torch.mean(loss)
    return loss
|
18 |
+
|
19 |
+
|
20 |
+
def dice_loss(inputs, targets, epsilon=1e-7):
    """
    Soft Dice loss: one minus the mean Dice coefficient.

    :param inputs: class scores with the class dimension at 1 and two trailing
        dimensions that are summed over; softmax is applied over dimension 1
    :param targets: class indices, one-hot encoded to ``inputs.shape[1]`` classes
    :param epsilon: added to the denominator to avoid division by zero
    :return: scalar tensor
    """
    n_classes = inputs.shape[1]
    one_hot = F.one_hot(targets.long(), num_classes=n_classes)
    # move the class axis from last to position 1 to line up with inputs
    one_hot = one_hot.permute(0, 3, 1, 2).float()

    probs = F.softmax(inputs, dim=1)
    one_hot = one_hot.type(probs.type())

    reduce_dims = (2, 3)
    overlap = 2 * (probs * one_hot).sum(dim=reduce_dims)
    total = probs.sum(dim=reduce_dims) + one_hot.sum(dim=reduce_dims)

    dice = overlap / (total + epsilon)
    return 1 - dice.mean()
|
model/metric.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from sklearn.metrics import f1_score as f1
|
3 |
+
from sklearn.metrics import precision_score, recall_score
|
4 |
+
|
5 |
+
|
6 |
+
def precision(output, target):
    """
    Precision of the arg-max predictions, computed with ``precision_score``.

    :param output: model output; the predicted class is the arg-max over dimension 1
    :param target: ground-truth labels with the same leading (batch) size
    """
    with torch.no_grad():
        pred = torch.argmax(output, dim=1)
        assert pred.shape[0] == len(target)
        # flatten both tensors and move them to the CPU for scikit-learn
        y_true = target.view(-1).cpu()
        y_pred = pred.view(-1).cpu()
        return precision_score(y_true, y_pred)
|
11 |
+
|
12 |
+
|
13 |
+
def recall(output, target):
    """
    Recall of the arg-max predictions, computed with ``recall_score``.

    :param output: model output; the predicted class is the arg-max over dimension 1
    :param target: ground-truth labels with the same leading (batch) size
    """
    with torch.no_grad():
        pred = torch.argmax(output, dim=1)
        assert pred.shape[0] == len(target)
        # flatten both tensors and move them to the CPU for scikit-learn
        y_true = target.view(-1).cpu()
        y_pred = pred.view(-1).cpu()
        return recall_score(y_true, y_pred)
|
18 |
+
|
19 |
+
|
20 |
+
def f1_score(output, target):
    """
    F1 score of the arg-max predictions, computed with scikit-learn's ``f1_score``.

    :param output: model output; the predicted class is the arg-max over dimension 1
    :param target: ground-truth labels with the same leading (batch) size
    """
    with torch.no_grad():
        pred = torch.argmax(output, dim=1)
        assert pred.shape[0] == len(target)
        # flatten both tensors and move them to the CPU for scikit-learn
        y_true = target.view(-1).cpu()
        y_pred = pred.view(-1).cpu()
        return f1(y_true, y_pred)
|
model/model.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from base import BaseModel
|
5 |
+
|
6 |
+
|
7 |
+
class UNet(BaseModel):
    """
    U-Net with four down-sampling and four up-sampling stages.

    :param n_channels: number of input channels
    :param n_classes: number of output channels of the final 1x1 convolution
    :param bilinear: use bilinear up-sampling instead of transposed convolutions
    """

    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels, self.n_classes, self.bilinear = n_channels, n_classes, bilinear

        # bilinear up-sampling keeps the channel count, so the affected
        # stages are built with half the channels
        factor = 2 if bilinear else 1

        # encoder
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)
        # decoder
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Return log-softmax class scores (over dimension 1) for ``x``."""
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        bottom = self.down4(skip4)

        out = self.up1(bottom, skip4)
        out = self.up2(out, skip3)
        out = self.up3(out, skip2)
        out = self.up4(out, skip1)

        return F.log_softmax(self.outc(out), dim=1)
|
38 |
+
|
39 |
+
|
40 |
+
class DoubleConv(BaseModel):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        # the intermediate width defaults to the output width
        mid_channels = mid_channels or out_channels

        first = [
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
        ]
        second = [
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*first, *second)

    def forward(self, x):
        """Apply both conv/BN/ReLU stages to ``x``."""
        out = self.double_conv(x)
        return out
|
58 |
+
|
59 |
+
|
60 |
+
class Down(BaseModel):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # 8x8 max-pool window with stride 2 and padding 3
        pool = nn.MaxPool2d(kernel_size=8, stride=2, padding=3)
        conv = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        """Pool ``x`` and run it through the double convolution."""
        out = self.maxpool_conv(x)
        return out
|
72 |
+
|
73 |
+
|
74 |
+
class Up(BaseModel):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if not bilinear:
            # the transposed convolution halves the channel count while up-sampling
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            # bilinear up-sampling keeps the channels; the double conv reduces them
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)

    def forward(self, x1, x2):
        """
        Up-sample ``x1``, pad it to the size of ``x2`` and fuse both.

        :param x1: feature map to up-sample
        :param x2: skip-connection feature map; concatenated in front of ``x1``
        """
        x1 = self.up(x1)

        # size difference along dimensions 2 and 3 (the two spatial dimensions)
        diff_h = x2.size(2) - x1.size(2)
        diff_w = x2.size(3) - x1.size(3)
        pad_left, pad_top = diff_w // 2, diff_h // 2

        x1 = F.pad(x1, [pad_left, diff_w - pad_left,
                        pad_top, diff_h - pad_top])
        merged = torch.cat([x2, x1], dim=1)
        return self.conv(merged)
|
98 |
+
|
99 |
+
|
100 |
+
class OutConv(BaseModel):
    """Final 1x1 convolution mapping features to the output channels."""

    def __init__(self, in_channels, out_channels):
        """
        :param in_channels: Channels of the incoming feature map.
        :param out_channels: Channels of the produced map.
        """
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Apply the 1x1 convolution to ``x``."""
        projected = self.conv(x)
        return projected
|
trainer/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .trainer import *
|
trainer/trainer.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from base import BaseTrainer
|
4 |
+
from torchvision.utils import make_grid
|
5 |
+
from utils import MetricTracker
|
6 |
+
|
7 |
+
|
8 |
+
class Trainer(BaseTrainer):
    """
    Trainer class

    Runs one training epoch (optionally followed by validation) and, on a
    schedule set by ``config['trainer']['adaptive_step']``, feeds the model's
    own predictions back into the training data loader before training.

    ``self.model``, ``self.criterion``, ``self.metric_ftns``,
    ``self.optimizer``, ``self.logger`` and ``self.writer`` are used but not
    assigned here; presumably ``BaseTrainer.__init__`` provides them.
    """

    def __init__(self, model, criterion, metric_ftns, optimizer, config, device,
                 data_loader, valid_data_loader=None, lr_scheduler=None, len_epoch=None):
        """
        :param model: Forwarded to ``BaseTrainer``.
        :param criterion: Forwarded to ``BaseTrainer``; called as ``criterion(output, target)``.
        :param metric_ftns: Forwarded to ``BaseTrainer``; each is called as ``met(output, target)``.
        :param optimizer: Forwarded to ``BaseTrainer``.
        :param config: Configuration; ``config['trainer']['adaptive_step']`` is read here.
        :param device: Device that batches are moved to.
        :param data_loader: Training loader. Must also expose ``inference`` and
            ``update_dataset`` when adaptive refreshing is enabled.
        :param valid_data_loader: Optional validation loader; validation runs only when given.
        :param lr_scheduler: Optional scheduler, stepped once per epoch.
        :param len_epoch: Accepted for interface compatibility only; the epoch
            length is always ``len(data_loader)``.
        """
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.device = device
        self.data_loader = data_loader
        self.len_epoch = len(self.data_loader)
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.adaptive_step = config['trainer']['adaptive_step']

        metric_names = [m.__name__ for m in self.metric_ftns]
        self.train_metrics = MetricTracker('loss', *metric_names, writer=self.writer)
        self.valid_metrics = MetricTracker('loss', *metric_names, writer=self.writer)

    def _is_adaptive_epoch(self, epoch):
        """Return True when the training data should be refreshed before ``epoch``."""
        step = self.adaptive_step
        # A missing, zero or negative step disables refreshing instead of
        # raising on the modulo below.
        if not step or step < 1:
            return False
        # Same schedule as ``epoch % step == 1`` for step >= 2, but also
        # satisfiable for step == 1 (``epoch % 1`` can never equal 1).
        return epoch > step and (epoch - 1) % step == 0

    def _refresh_dataset_from_predictions(self):
        """Predict on ``data_loader.inference`` and hand the results to ``update_dataset``."""
        inference_loader = self.data_loader.inference
        dataset = inference_loader.dataset
        batch_size = inference_loader.batch_size

        self.model.eval()
        with torch.no_grad():
            for batch_idx, (data, _) in enumerate(inference_loader):
                data = data.to(self.device)
                output = self.model(data)

                # Patch indices are derived from the batch position, so this
                # assumes the inference loader yields patches in dataset
                # order -- TODO confirm against the loader.
                start = batch_size * batch_idx
                patch_idx = torch.arange(start, start + data.shape[0])
                pred = torch.argmax(output, dim=1)
                dataset.patches.store_data(patch_idx, [pred.unsqueeze(1)])

        preds = [dataset.patches.combine(idx, data_idx=0)[0].cpu()
                 for idx in range(len(dataset.data))]

        self.data_loader.update_dataset(preds)
        # The loader may have changed size after the update.
        self.len_epoch = len(self.data_loader)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        if self._is_adaptive_epoch(epoch):
            self._refresh_dataset_from_predictions()

        self.model.train()
        self.train_metrics.reset()
        for batch_idx, (data, target) in enumerate(self.data_loader):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            self.train_metrics.update('loss', loss.item())
            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(output, target))

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch,
                    self._progress(batch_idx),
                    loss.item()))
                self.writer.add_image('input', make_grid(
                    data.cpu(), nrow=8, normalize=True))

            if batch_idx == self.len_epoch:
                break

        self.writer.next()
        self.train_metrics.add_scalers()
        log = self.train_metrics.result()

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
                data, target = data.to(self.device), target.to(self.device)

                output = self.model(data)
                loss = self.criterion(output, target)

                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(
                        met.__name__, met(output, target))
                self.writer.add_image('input', make_grid(
                    data.cpu(), nrow=8, normalize=True))

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')

        self.writer.next('valid')
        self.valid_metrics.add_scalers()
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        """
        Format the position within the epoch as ``[current/total (pct%)]``.

        Counts samples when the loader exposes ``n_samples``, batches otherwise.

        :param batch_idx: Integer, zero-based index of the current batch.
        :return: Formatted progress string.
        """
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
|
utils/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .util import *
|
utils/util.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from collections import OrderedDict
|
3 |
+
from itertools import repeat
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import Any, Callable, List, Optional, Tuple
|
6 |
+
import pandas as pd
|
7 |
+
import PIL.Image as Image
|
8 |
+
import torch
|
9 |
+
|
10 |
+
|
11 |
+
def read_json(fname):
    """
    Load a JSON file, preserving key order at every nesting level.

    :param fname: Path (``str`` or ``Path``) of the file to read.
    :return: Parsed content; every JSON object becomes an ``OrderedDict``.
    """
    fname = Path(fname)
    # JSON is UTF-8 by specification; do not depend on the locale's default
    # encoding, which mis-decodes non-ASCII text on some platforms.
    with fname.open('rt', encoding='utf-8') as handle:
        return json.load(handle, object_hook=OrderedDict)
|
15 |
+
|
16 |
+
|
17 |
+
def write_json(content, fname):
    """
    Serialize ``content`` as JSON into ``fname``.

    The output is indented by four spaces and keys keep their insertion order.

    :param content: JSON-serializable object to write.
    :param fname: Path (``str`` or ``Path``) of the destination file.
    """
    destination = Path(fname)
    with destination.open('wt') as out_file:
        json.dump(content, out_file, sort_keys=False, indent=4)
|
21 |
+
|
22 |
+
|
23 |
+
def pil_loader(fname: str) -> Image.Image:
    """Open ``fname`` with ``PIL.Image.open`` and return the image object."""
    image = Image.open(fname)
    return image
|
25 |
+
|
26 |
+
|
27 |
+
def prepare_device(n_gpu_use):
    """
    Pick the torch device and the GPU indices to use.

    The requested GPU count is clamped to what ``torch.cuda.device_count()``
    reports, printing a warning whenever the request cannot be satisfied.

    :param n_gpu_use: Number of GPUs requested.
    :return: Tuple ``(device, list_ids)``: ``cuda:0`` or ``cpu``, and the
        indices ``0..n-1`` of the GPUs that will be used.
    """
    available = torch.cuda.device_count()
    granted = n_gpu_use

    if available == 0 and granted > 0:
        print("Warning: There's no GPU available on this machine,"
              "training will be performed on CPU.")
        granted = 0

    if granted > available:
        print(f"Warning: The number of GPU's configured to use is {granted}, but only {available} are "
              "available on this machine.")
        granted = available

    device_name = 'cuda:0' if granted > 0 else 'cpu'
    return torch.device(device_name), list(range(granted))
|
43 |
+
|
44 |
+
|
45 |
+
class TransformMultiple:
    """
    Apply a single transform jointly to several tensors.

    The tensors are concatenated along dimension 0, transformed in one call,
    and split back into pieces with the original sizes along that dimension.
    """

    def __init__(self, transform: Optional[Callable] = None) -> None:
        """
        :param transform: Callable applied to the concatenated tensor, or
            ``None`` to pass data through untouched.
        """
        self.transform = transform

    def __call__(self, data: Any) -> Tuple:
        """
        :param data: Sequence of tensors to transform together.
        :return: ``data`` itself when no transform is set, otherwise the
            transformed tensor split back into one piece per input.
        """
        if self.transform is None:
            return data
        stacked = torch.cat(data)
        transformed = self.transform(stacked)
        lengths = [item.size()[0] for item in data]
        return torch.split(transformed, lengths)

    def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
        """Prefix the first repr line with ``head`` and align the rest under it."""
        repr_lines = transform.__repr__().splitlines()
        padding = " " * len(head)
        formatted = [f"{head}{repr_lines[0]}"]
        formatted.extend(f"{padding}{line}" for line in repr_lines[1:])
        return formatted

    def __repr__(self) -> str:
        """Return the class name, followed by the transform's repr when one is set."""
        parts = [type(self).__name__]
        if self.transform is not None:
            parts.extend(self._format_transform_repr(self.transform, "Transform: "))
        return "\n".join(parts)
|
65 |
+
|
66 |
+
|
67 |
+
class MetricTracker:
    """
    Running total, count and average for a fixed set of named metrics.

    Values live in a DataFrame indexed by metric name with the columns
    ``total``, ``counts`` and ``average``.
    """

    _COLUMNS = ['total', 'counts', 'average']

    def __init__(self, *keys, writer=None):
        """
        :param keys: Names of the metrics to track.
        :param writer: Optional object with ``add_scalar(key, value)`` that
            receives the averages from :meth:`add_scalers`.
        """
        self.writer = writer
        self._keys = list(keys)
        self._data = None
        self.reset()

    def reset(self):
        """Set every total, count and average back to zero."""
        # Rebuild the frame instead of writing into ``Series.values``: that
        # array is read-only under pandas Copy-on-Write. Object dtype lets the
        # cells keep whatever numeric type the metric functions return.
        zeros = {col: [0] * len(self._keys) for col in self._COLUMNS}
        self._data = pd.DataFrame(zeros, index=self._keys, columns=self._COLUMNS, dtype=object)

    def update(self, key, value, n=1):
        """
        Add ``value`` with weight ``n`` to ``key`` and refresh its average.

        :param key: Name of a tracked metric; an unknown name raises ``KeyError``.
        :param value: Metric value for this update.
        :param n: Weight of the update (e.g. number of samples it covers).
        """
        # ``.at`` writes straight into the frame; chained assignment such as
        # ``frame.total[key] += ...`` can end up modifying a temporary copy.
        total = self._data.at[key, 'total'] + value * n
        count = self._data.at[key, 'counts'] + n
        self._data.at[key, 'total'] = total
        self._data.at[key, 'counts'] = count
        self._data.at[key, 'average'] = total / count

    def add_scalers(self):
        """Send the current average of every metric to the writer, if one is set."""
        if self.writer is None:
            return
        for key in self._data.index:
            self.writer.add_scalar(key, self._data.at[key, 'average'])

    def avg(self, key):
        """Return the current average of ``key``."""
        return self._data.at[key, 'average']

    def result(self):
        """Return a dict mapping every metric name to its current average."""
        return dict(self._data['average'])
|