|
import pandas as pd |
|
import os |
|
from pathlib import Path |
|
from monai.data import Dataset, DataLoader |
|
import numpy as np |
|
from torch.utils.data import Subset |
|
import pytorch_lightning as pl |
|
from torchsampler import ImbalancedDatasetSampler |
|
from torch.utils.data.dataloader import default_collate |
|
|
|
from transforms import get_training_transforms, get_base_transforms |
|
|
|
class VerseDataModule(pl.LightningDataModule):
    """Lightning data module for the VerSe vertebral-fracture dataset.

    Resolves image/mask directories and the label CSV under
    ``hparams.dataset_path``, builds per-split MONAI datasets in ``setup``,
    and serves train/val/test dataloaders (optionally with class-balanced
    oversampling for training).
    """

    def __init__(self, hparams):
        """Resolve dataset paths, load the label table and build transforms.

        Args:
            hparams: mapping-like hyperparameters; must provide
                ``dataset_path``, ``transforms``, ``mask``, ``task``,
                ``fold``, ``batch_size`` and ``oversampling``.

        Raises:
            RuntimeError: if masks are requested (``hparams.mask != 'none'``)
                but the ``seg`` folder is missing from the dataset path.
        """
        super().__init__()
        self.save_hyperparameters(dict(hparams), logger=False)
        self.data_dir = Path(self.hparams.dataset_path)
        self.csv_path = self.data_dir / 'fxall_labels.csv'

        # ModelsGenesis-style pipelines read the unprocessed volumes.
        if "modelsgenesis" in hparams.transforms:
            self.image_dir = self.data_dir / 'raw'
        else:
            self.image_dir = self.data_dir / 'ct'

        # Fall back to a flat, slice-based layout when no volume folder exists.
        if not self.image_dir.exists():
            self.image_dir = self.data_dir
            self.csv_path = self.data_dir / 'slice_labels.csv'

        self.mask_dir = self.data_dir / 'seg'
        if hparams.mask != 'none' and not self.mask_dir.exists():
            raise RuntimeError("Configured to use masks, but 'seg' folder missing in dataset path")

        self.df = pd.read_csv(self.csv_path, index_col=0)

        # Drop rows whose image file is not actually present on disk.
        if "path" in self.df.columns:
            self.df = self.df[self.df["path"].apply(lambda p: (self.image_dir / p).exists())]

        # Dict transforms key on 'image'; alias it if only 'path' is present.
        if "image" not in self.df.columns:
            self.df['image'] = self.df['path']

        self.transforms = {
            'training': get_training_transforms(hparams, self.image_dir, self.mask_dir),
            'validation': get_base_transforms(hparams, self.image_dir, self.mask_dir),
            'test': get_base_transforms(hparams, self.image_dir, self.mask_dir)
        }

        self.datasets = {}  # split name -> monai Dataset
        self.idxs = {}      # split name -> positional row indices into self.df

    def setup(self, stage=None):
        """Build the datasets for the splits required by ``stage``.

        ``fit`` (or ``None``) prepares training + validation; any other
        stage prepares the test split only.
        """
        if stage == 'fit' or stage is None:
            phases = ['training', 'validation']
        else:
            phases = ['test']

        for split in phases:
            # The CSV stores one split-assignment column per CV fold.
            mask = self.df[f'split_{self.hparams.fold}'] == split
            idxs = np.where(mask)[0]
            self.idxs[split] = idxs
            self.datasets[split] = Dataset(
                self.df.iloc[idxs].to_dict('records'),
                transform=self.transforms[split]
            )

    def get_label(self, data):
        """Return per-sample labels of the training split for the sampler.

        Used as ``callback_get_label`` by ``ImbalancedDatasetSampler``; the
        ``data`` argument (the dataset) is ignored because labels are read
        straight from the already-filtered dataframe.

        Raises:
            ValueError: if ``hparams.task`` is not one of the known tasks.
        """
        train_df = self.df.iloc[self.idxs['training']]
        grading = train_df.fx_grading
        if self.hparams.task == 'detection':
            return train_df.fx
        elif self.hparams.task == 'grading':
            return grading
        elif self.hparams.task == 'simple_grading':
            # BUG FIX: the original applied scalar logic to a whole Series —
            # "grading in [2,3]" tests membership against the Series *index*
            # and "if grading > 3" raises "truth value of a Series is
            # ambiguous". Map the grades element-wise instead:
            # {2, 3} -> 1, g > 3 -> g - 2, otherwise unchanged.
            return pd.Series(
                np.select(
                    [grading.isin([2, 3]), grading > 3],
                    [1, grading - 2],
                    default=grading,
                ),
                index=grading.index,
            )
        raise ValueError(f"Unknown task: {self.hparams.task}")

    def train_dataloader(self):
        """Training dataloader; class-balanced sampling when oversampling is on."""
        sampler = None
        if self.hparams.oversampling:
            sampler = ImbalancedDatasetSampler(
                # Draw twice the number of fracture cases so a balanced epoch
                # holds roughly as many positives as negatives.
                # int(...) guards against a numpy integer from .sum().
                num_samples=int(self.df.iloc[self.idxs['training']].fx.sum()) * 2,
                dataset=self.datasets['training'],
                callback_get_label=self.get_label,
            )
        return DataLoader(
            self.datasets['training'],
            batch_size=self.hparams.batch_size,
            sampler=sampler,
            num_workers=2,
            # A custom sampler and shuffle=True are mutually exclusive.
            shuffle=not self.hparams.oversampling
        )

    def val_dataloader(self):
        """Validation dataloader (no shuffling, no oversampling)."""
        return DataLoader(
            self.datasets['validation'],
            batch_size=self.hparams.batch_size,
            num_workers=8,
        )

    def test_dataloader(self):
        """Test dataloader (no shuffling, no oversampling)."""
        return DataLoader(
            self.datasets['test'],
            batch_size=self.hparams.batch_size,
            num_workers=8,
        )
|
|