# Author: Paul Engstler
# Initial commit (92f0e98)
import pandas as pd
import os
from pathlib import Path
from monai.data import Dataset, DataLoader
import numpy as np
from torch.utils.data import Subset
import pytorch_lightning as pl
from torchsampler import ImbalancedDatasetSampler
from torch.utils.data.dataloader import default_collate
from transforms import get_training_transforms, get_base_transforms
class VerseDataModule(pl.LightningDataModule):
    """Lightning data module for the VerSe vertebral-fracture dataset.

    Reads a per-sample label CSV (``fxall_labels.csv``, or the legacy
    ``slice_labels.csv``) from ``hparams.dataset_path``, filters out rows
    whose image file is missing, partitions samples by the official
    ``split_<fold>`` column, and wraps each split in a MONAI ``Dataset``
    with the configured transforms.
    """

    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(dict(hparams), logger=False)
        self.data_dir = Path(self.hparams.dataset_path)
        self.csv_path = self.data_dir / 'fxall_labels.csv'
        # ModelsGenesis-style transforms operate on the raw volumes;
        # everything else uses the preprocessed 'ct' folder.
        if "modelsgenesis" in hparams.transforms:
            self.image_dir = self.data_dir / 'raw'
        else:
            self.image_dir = self.data_dir / 'ct'
        if not os.path.exists(self.image_dir):
            # legacy support: flat layout with slice-level labels
            self.image_dir = self.data_dir
            self.csv_path = self.data_dir / 'slice_labels.csv'
        self.mask_dir = self.data_dir / 'seg'
        if hparams.mask != 'none' and not os.path.exists(self.mask_dir):
            raise RuntimeError("Configured to use masks, but 'seg' folder missing in dataset path")
        self.df = pd.read_csv(self.csv_path, index_col=0)
        # TODO temporary fix to check for non-existing files
        if "path" in self.df.columns:
            self.df = self.df[self.df["path"].apply(lambda p: os.path.exists(self.image_dir / p))]
        # FIXME slice_labels.csv provides the path in 'image', fxall_labels.csv in 'path'
        if "image" not in self.df.columns:
            self.df['image'] = self.df['path']
        self.transforms = {
            'training': get_training_transforms(hparams, self.image_dir, self.mask_dir),
            'validation': get_base_transforms(hparams, self.image_dir, self.mask_dir),
            'test': get_base_transforms(hparams, self.image_dir, self.mask_dir)
        }
        self.datasets = {}
        self.idxs = {}

    def setup(self, stage=None):
        """Create MONAI datasets for the splits required by *stage*.

        ``fit`` (or ``None``) builds 'training' and 'validation';
        anything else builds 'test'. Positional indices per split are
        kept in ``self.idxs`` for later label lookups.
        """
        # dropping samples without fracture grading
        # graded_idxs = ~self.df.fx.isna()
        # vertebrae_level_idxs = self.df.level_idx >= self.hparams.min_vertebrae_level
        # included_idxs = graded_idxs & vertebrae_level_idxs
        if stage == 'fit' or stage is None:
            phases = ['training', 'validation']
        else:
            phases = ['test']
        for split in phases:
            # get official verse partitions
            idxs = self.df[f'split_{self.hparams.fold}'] == split
            # idxs = np.where(included_idxs & idxs)[0]
            idxs = np.where(idxs)[0]
            self.idxs[split] = idxs
            self.datasets[split] = Dataset(
                self.df.iloc[idxs].to_dict('records'),
                transform=self.transforms[split]
            )

    def get_label(self, data):
        """Return per-sample labels for the training split as a Series.

        Used as ``callback_get_label`` by ``ImbalancedDatasetSampler``;
        *data* (the dataset) is ignored since labels come from the
        dataframe directly.

        Raises:
            ValueError: if ``hparams.task`` is not a known task name.
        """
        train_df = self.df.iloc[self.idxs['training']]
        grading = train_df.fx_grading
        if self.hparams.task == 'detection':
            return train_df.fx
        if self.hparams.task == 'grading':
            return grading
        if self.hparams.task == 'simple_grading':
            # Collapse grades: 0/1 unchanged, {2, 3} -> 1, g > 3 -> g - 2.
            # BUG fix: the original compared the whole Series against
            # scalars (`if grading in [2, 3]`), which raises ValueError
            # for any Series longer than one element; vectorize instead.
            simplified = grading.where(grading <= 3, grading - 2)
            return simplified.mask(grading.isin([2, 3]), 1)
        raise ValueError(f"Unknown task: {self.hparams.task}")

    def train_dataloader(self):
        """Training loader; class-balanced sampling when oversampling is on."""
        return DataLoader(
            self.datasets['training'],
            batch_size=self.hparams.batch_size,
            sampler=ImbalancedDatasetSampler(
                # draw twice as many samples as there are fractures;
                # cast to int since pandas .sum() yields a numpy scalar
                num_samples=int(self.df.iloc[self.idxs['training']].fx.sum()) * 2,
                dataset=self.datasets['training'],
                callback_get_label=self.get_label,
            ) if self.hparams.oversampling else None,
            num_workers=2,
            # sampler and shuffle are mutually exclusive in DataLoader
            shuffle=not self.hparams.oversampling
        )

    def val_dataloader(self):
        """Validation loader (no shuffling, no sampler)."""
        return DataLoader(
            self.datasets['validation'],
            batch_size=self.hparams.batch_size,
            num_workers=8,
        )

    def test_dataloader(self):
        """Test loader (no shuffling, no sampler)."""
        return DataLoader(
            self.datasets['test'],
            batch_size=self.hparams.batch_size,
            num_workers=8,
        )