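"""Dataset creation and DataLoader wrappers for training and testing,
with optional shuffled or distributed sampling."""
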
import torch.utils.data
from data.base_data_loader import BaseDataLoader


def CreateDataset(opt):
    # Import locally so the dataset module is only loaded when needed.
    from data.aligned_dataset import AlignedDataset

    dataset = AlignedDataset()
    print("dataset [%s] was created" % dataset.name())
    dataset.initialize(opt)
    return dataset


class CustomDatasetDataLoader(BaseDataLoader):
    """Training loader: wraps the dataset in a DataLoader whose sampler
    handles shuffling and, optionally, distributed training."""

    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            sampler=data_sampler(self.dataset,
                                 not opt.serial_batches, opt.distributed),
            num_workers=int(opt.nThreads),
            pin_memory=True)

    def get_loader(self):
        return self.dataloader

    def __len__(self):
        return min(len(self.dataset), self.opt.max_dataset_size)


def data_sampler(dataset, shuffle, distributed):
    # Pick a sampler: DistributedSampler shards the dataset across
    # processes; otherwise fall back to random or sequential sampling.
    if distributed:
        return torch.utils.data.distributed.DistributedSampler(
            dataset, shuffle=shuffle)
    if shuffle:
        return torch.utils.data.RandomSampler(dataset)
    return torch.utils.data.SequentialSampler(dataset)


def sample_data(loader):
    # Turn a finite DataLoader into an endless stream of batches,
    # restarting from the beginning whenever the loader is exhausted.
    while True:
        for batch in loader:
            yield batch


class CustomTestDataLoader(BaseDataLoader):
    """Test loader: same as the training loader but without a custom
    sampler, so batches are read sequentially."""

    def name(self):
        return 'CustomTestDataLoader'

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            num_workers=int(opt.nThreads),
            pin_memory=True)

    def get_loader(self):
        return self.dataloader

    def __len__(self):
        return min(len(self.dataset), self.opt.max_dataset_size)
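

# A minimal usage sketch, not part of the original module: the real option
# set comes from the project's option parser, so the Namespace below only
# mimics the fields this file actually reads. It also assumes that
# BaseDataLoader.initialize stores opt as self.opt (used by __len__) and
# that AlignedDataset can initialize from these options alone.
if __name__ == '__main__':
    from argparse import Namespace

    opt = Namespace(
        batchSize=4,                        # batch size passed to DataLoader
        serial_batches=False,               # False -> shuffle during training
        distributed=False,                  # True -> use DistributedSampler
        nThreads=2,                         # DataLoader worker processes
        max_dataset_size=float('inf'))      # cap on reported dataset length

    data_loader = CustomDatasetDataLoader()
    data_loader.initialize(opt)
    print('#training samples = %d' % len(data_loader))

    # sample_data yields batches forever, which suits iteration-based
    # training loops that do not align with epoch boundaries.
    batches = sample_data(data_loader.get_loader())
    first_batch = next(batches)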