# DORNet: data/nyu_dataloader.py
# Uploaded by RaynWu2002 ("Upload 49 files", commit 68c537d, verified).
import os

import numpy as np
from PIL import Image
from torch.utils.data import Dataset
class NYU_v2_datset(Dataset):
    """NYU Depth v2 dataset for guided depth super-resolution.

    Each sample pairs a guidance image with its ground-truth depth map and a
    low-resolution depth input synthesized on the fly: the depth map is
    bicubically down-sampled by ``scale`` and then bicubically up-sampled
    back to its original size.

    The class name keeps its historical spelling (``datset``) because callers
    refer to it by that name.
    """

    def __init__(self, root_dir, scale=8, train=True, transform=None):
        """
        Args:
            root_dir (string or path-like): Directory containing the ``.npy``
                arrays: ``train_depth_split.npy`` / ``train_images_split.npy``
                for the training split, ``test_depth.npy`` /
                ``test_images_v2.npy`` for the test split.
            scale (int or float): Down-sampling factor used to synthesize the
                low-resolution depth input.
            train (bool): Load the training split if True, else the test split.
            transform (callable, optional): Applied to the guidance image, the
                ground-truth depth and the low-resolution depth. Its result
                must provide ``.float()`` (e.g. a torch tensor).
        """
        self.root_dir = root_dir
        self.transform = transform
        self.scale = scale
        self.train = train
        if train:
            depth_file, image_file = 'train_depth_split.npy', 'train_images_split.npy'
        else:
            depth_file, image_file = 'test_depth.npy', 'test_images_v2.npy'
        # Whole arrays are loaded into memory; axis 0 indexes samples.
        self.depths = np.load(os.path.join(root_dir, depth_file))
        self.images = np.load(os.path.join(root_dir, image_file))

    def __len__(self):
        """Return the number of samples (length of the depth array's first axis)."""
        return self.depths.shape[0]

    def __getitem__(self, idx):
        """Return ``{'guidance': image, 'lr': lr_depth, 'gt': depth}`` for sample ``idx``.

        Without a transform the values are numpy arrays. With a transform,
        each value is ``transform(...).float()``.
        """
        depth = self.depths[idx]
        image = self.images[idx]
        h, w = depth.shape[:2]
        # int() so a float ``scale`` (as documented) still yields the integer
        # size PIL requires; max(1, ...) avoids a zero-sized intermediate image
        # when ``scale`` exceeds the image size.
        low_w = max(1, int(w // self.scale))
        low_h = max(1, int(h // self.scale))
        # Degrade the ground truth: bicubic down-sample, then bicubic
        # up-sample back to the ground-truth resolution (PIL sizes are (w, h)).
        lr = np.array(
            Image.fromarray(depth.squeeze())
            .resize((low_w, low_h), Image.BICUBIC)
            .resize((w, h), Image.BICUBIC)
        )
        if self.transform is not None:
            image = self.transform(image).float()
            depth = self.transform(depth).float()
            # lr comes back from PIL without a channel axis for single-channel
            # depth; add a trailing one, presumably so the transform sees an
            # HxWx1 array (TODO confirm against the transform in use).
            lr = self.transform(np.expand_dims(lr, 2)).float()
        return {'guidance': image, 'lr': lr, 'gt': depth}