import torch.utils.data
from data.base_data_loader import BaseDataLoader


def CreateDataset(opt):
    """Build and initialize the dataset described by ``opt``.

    Always constructs an ``AlignedDataset``, prints its name, then forwards
    ``opt`` unchanged to its ``initialize`` method.

    Returns:
        The initialized ``AlignedDataset`` instance.
    """
    # Local import kept from the original; presumably this defers loading the
    # dataset implementation until a dataset is actually built.
    from data.aligned_dataset import AlignedDataset

    dataset = AlignedDataset()
    print("dataset [%s] was created" % (dataset.name()))
    dataset.initialize(opt)
    return dataset


class CustomDatasetDataLoader(BaseDataLoader):
    """Training loader: wraps the dataset in a ``torch.utils.data.DataLoader``.

    The sampler is chosen by ``data_sampler`` from ``opt.serial_batches`` and
    ``opt.distributed``.
    """

    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        """Create the dataset and the DataLoader that iterates over it."""
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            # serial_batches=True turns shuffling off.
            sampler=data_sampler(self.dataset, not opt.serial_batches, opt.distributed),
            num_workers=int(opt.nThreads),
            pin_memory=True)

    def get_loader(self):
        """Return the underlying ``torch.utils.data.DataLoader``."""
        return self.dataloader

    def __len__(self):
        # NOTE(review): caps the reported length at opt.max_dataset_size, but
        # the DataLoader itself is not truncated. self.opt is presumably set
        # by BaseDataLoader.initialize -- confirm.
        return min(len(self.dataset), self.opt.max_dataset_size)


def data_sampler(dataset, shuffle, distributed):
    """Return the sampler that orders the indices drawn from ``dataset``.

    Args:
        dataset: the dataset to sample from.
        shuffle: randomize the index order when True; iterate in index order
            otherwise.
        distributed: when True, return a ``DistributedSampler`` (which
            honours ``shuffle``) instead of a single-process sampler.
            ``DistributedSampler`` constructed without ``num_replicas``/``rank``
            requires an initialized process group.
    """
    if distributed:
        return torch.utils.data.DistributedSampler(dataset, shuffle=shuffle)
    if shuffle:
        return torch.utils.data.RandomSampler(dataset)
    return torch.utils.data.SequentialSampler(dataset)


def sample_data(loader):
    """Yield batches from ``loader`` forever, restarting it after each pass.

    If the loader's sampler exposes ``set_epoch`` (as ``DistributedSampler``
    does), it is advanced before each pass. Without that call, a shuffling
    ``DistributedSampler`` replays the same permutation on every pass. The
    first pass uses epoch 0, the sampler's default.

    Raises:
        ValueError: if a complete pass over ``loader`` yields no batches,
            which would otherwise loop forever without yielding.
    """
    sampler = getattr(loader, "sampler", None)
    epoch = 0
    while True:
        if hasattr(sampler, "set_epoch"):
            sampler.set_epoch(epoch)
        produced_any = False
        for batch in loader:
            produced_any = True
            yield batch
        if not produced_any:
            raise ValueError("sample_data: loader produced no batches")
        epoch += 1


class CustomTestDataLoader(BaseDataLoader):
    """Evaluation loader: wraps the dataset in a plain DataLoader.

    No sampler is passed, so the DataLoader iterates sequentially, and
    ``opt.distributed`` is not consulted.
    """

    def name(self):
        # Previously returned 'CustomDatasetDataLoader' (copy-paste bug).
        return 'CustomTestDataLoader'

    def initialize(self, opt):
        """Create the dataset and a sequential DataLoader over it."""
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            num_workers=int(opt.nThreads),
            pin_memory=True)

    def get_loader(self):
        """Return the underlying ``torch.utils.data.DataLoader``."""
        return self.dataloader

    def __len__(self):
        # NOTE(review): reported length is capped at opt.max_dataset_size,
        # but the DataLoader is not truncated to match.
        return min(len(self.dataset), self.opt.max_dataset_size)