from torch.utils.data import Dataset
from PIL import Image
import numpy as np


class NYU_v2_datset(Dataset):
    """NYU Depth v2 dataset for guided depth super-resolution."""

    def __init__(self, root_dir, scale=8, train=True, transform=None):
        """
        Args:
            root_dir (string): Directory with all the images.
            scale (int): Downsampling factor used to create the low-resolution depth.
            train (bool): Load the training split if True, otherwise the test split.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.scale = scale
        self.train = train

        if train:
            self.depths = np.load('%s/train_depth_split.npy' % root_dir)
            self.images = np.load('%s/train_images_split.npy' % root_dir)
        else:
            self.depths = np.load('%s/test_depth.npy' % root_dir)
            self.images = np.load('%s/test_images_v2.npy' % root_dir)

    def __len__(self):
        return self.depths.shape[0]

    def __getitem__(self, idx):
        depth = self.depths[idx]
        image = self.images[idx]
        h, w = depth.shape[:2]
        s = self.scale

        # Simulate a degraded depth map: bicubic downsample by the scale
        # factor, then bicubic upsample back to the original resolution.
        lr = np.array(
            Image.fromarray(depth.squeeze())
            .resize((w // s, h // s), Image.BICUBIC)
            .resize((w, h), Image.BICUBIC)
        )

        if self.transform:
            image = self.transform(image).float()
            depth = self.transform(depth).float()
            lr = self.transform(np.expand_dims(lr, 2)).float()

        sample = {'guidance': image, 'lr': lr, 'gt': depth}
        return sample
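

# A minimal usage sketch, not part of the original file: the root_dir path,
# batch size, and the torchvision ToTensor transform are assumptions about
# how this dataset is typically consumed, not something the source specifies.
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    from torchvision import transforms

    # ToTensor converts the H x W x C numpy arrays to C x H x W float tensors.
    dataset = NYU_v2_datset(root_dir='./nyu_data', scale=8, train=True,
                            transform=transforms.ToTensor())
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

    for batch in loader:
        guidance = batch['guidance']  # RGB guide image, e.g. (B, 3, H, W)
        lr = batch['lr']              # degraded low-resolution depth, (B, 1, H, W)
        gt = batch['gt']              # ground-truth depth, (B, 1, H, W)
        print(guidance.shape, lr.shape, gt.shape)
        break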