DenseSR / dataset.py
VanLinLin's picture
birth
3f7c489
import numpy as np
import os
from torch.utils.data import Dataset
import torch
from utils import load_normal, load_ssao, load_img, depthToPoint, process_normal, load_depth, Augment_RGB_torch
import torch.nn.functional as F
import random
# Single module-level augmenter instance shared by the dataset classes below.
augment = Augment_RGB_torch()

# Names of every callable attribute of ``augment`` that does not start with an
# underscore, in the (sorted) order produced by ``dir``.  Entries are later
# resolved back to bound methods with ``getattr(augment, name)``.
transforms_aug = [
    name
    for name in dir(augment)
    if callable(getattr(augment, name)) and not name.startswith('_')
]
##################################################################################################
class DataLoaderTrain(Dataset):
    """Training dataset: paired images plus geometry, augmented and randomly cropped.

    ``rgb_dir`` must contain the sub-directories ``shadow_free`` (targets),
    ``origin`` (inputs), ``depth`` and ``normal``.  Samples are paired purely
    by position in each directory's sorted listing, so the four listings must
    line up for the pairs to be meaningful.

    Args:
        rgb_dir: Root directory holding the four sub-directories.
        img_options: Mapping that must provide ``'patch_size'``, the side
            length of the square crop taken in ``__getitem__``.
        target_transform: Stored on the instance; not applied by this class.
        debug: When true, the dataset is limited to at most 100 samples.
    """

    def __init__(self, rgb_dir, img_options=None, target_transform=None, debug=False):
        super(DataLoaderTrain, self).__init__()

        self.target_transform = target_transform

        gt_dir = 'shadow_free'
        input_dir = 'origin'
        depth_dir = 'depth'
        normal_dir = 'normal'

        clean_files = sorted(os.listdir(os.path.join(rgb_dir, gt_dir)))  # shadow free
        noisy_files = sorted(os.listdir(os.path.join(rgb_dir, input_dir)))  # origin
        depth_files = sorted(os.listdir(os.path.join(rgb_dir, depth_dir)))  # depth
        normal_files = sorted(os.listdir(os.path.join(rgb_dir, normal_dir)))  # normal map

        self.clean_filenames = [os.path.join(rgb_dir, gt_dir, x) for x in clean_files]
        self.noisy_filenames = [os.path.join(rgb_dir, input_dir, x) for x in noisy_files]
        self.depth_filenames = [os.path.join(rgb_dir, depth_dir, x) for x in depth_files]
        self.normal_filenames = [os.path.join(rgb_dir, normal_dir, x) for x in normal_files]

        self.img_options = img_options

        # In debug mode use at most 100 samples.  The cap is clamped to the
        # number of available inputs so a small dataset cannot be indexed
        # past its end.
        total = len(self.noisy_filenames)
        self.tar_size = min(100, total) if debug else total

    def __len__(self):
        """Return the number of samples exposed by the dataset."""
        return self.tar_size

    def __getitem__(self, index):
        """Load, augment and crop one training sample.

        Returns:
            ``(clean, noisy, point, normal, clean_filename, noisy_filename)``.
            The four tensors are permuted with ``(2, 0, 1)`` (presumably
            HWC -> CHW -- confirm against the loaders in ``utils``) and
            cropped to ``patch_size`` on their last two dimensions.

        Raises:
            ValueError: If the loaded image is smaller than ``patch_size``.
        """
        tar_index = index % self.tar_size
        clean_path = self.clean_filenames[tar_index]
        noisy_path = self.noisy_filenames[tar_index]

        clean = np.float32(load_img(clean_path))
        noisy = np.float32(load_img(noisy_path))
        depth = np.float32(load_depth(self.depth_filenames[tar_index]))
        normal = np.float32(load_normal(self.normal_filenames[tar_index]))

        # NOTE(review): 60 is presumably the camera field of view expected by
        # depthToPoint -- confirm against utils.
        point = depthToPoint(60, depth)
        normal = process_normal(normal)

        clean = torch.from_numpy(clean)
        noisy = torch.from_numpy(noisy)
        point = torch.from_numpy(point)
        normal = torch.from_numpy(normal)

        # Normalise the point map by twice the mean of its channel 2
        # (presumably the depth/z coordinate).
        point = point / (2 * point[:, :, 2].mean())

        clean = clean.permute(2, 0, 1)
        noisy = noisy.permute(2, 0, 1)
        point = point.permute(2, 0, 1)
        normal = normal.permute(2, 0, 1)

        clean_filename = os.path.split(clean_path)[-1]
        noisy_filename = os.path.split(noisy_path)[-1]

        # Pick a random rotation setting and one of the first three
        # augmentations, then apply the same one to every modality.
        augment.rotate = random.randint(-20, 20)
        apply_trans = transforms_aug[random.randint(0, 2)]
        transform = getattr(augment, apply_trans)

        # NOTE(review): the original carried a "[0, 1]" remark here,
        # presumably the expected value range of the images -- confirm.
        clean = transform(clean)
        noisy = transform(noisy)
        point = transform(point)
        normal = transform(normal)

        # Crop input and target with one shared random window.
        ps = self.img_options['patch_size']
        scale = 1  # random.uniform(1, 1.5)
        H = noisy.shape[1]
        W = noisy.shape[2]
        scaled_ps = int(scale * ps)

        if H != scaled_ps or W != scaled_ps:
            if H < scaled_ps or W < scaled_ps:
                # Same exception type NumPy would raise for an empty range,
                # but with an actionable message.
                raise ValueError(
                    'image %s (%dx%d) is smaller than the crop size %d'
                    % (noisy_filename, H, W, scaled_ps)
                )
            r = np.random.randint(0, H - scaled_ps + 1)
            c = np.random.randint(0, W - scaled_ps + 1)
            clean = clean[:, r:r + scaled_ps, c:c + scaled_ps]
            noisy = noisy[:, r:r + scaled_ps, c:c + scaled_ps]
            point = point[:, r:r + scaled_ps, c:c + scaled_ps]
            normal = normal[:, r:r + scaled_ps, c:c + scaled_ps]

        # Scale back to patch_size.  Only relevant if the random scale above
        # is re-enabled; geometry maps use nearest-neighbour resampling.
        if scale != 1:
            clean = F.interpolate(clean.unsqueeze(0), size=[ps, ps], mode='bilinear').squeeze(0)
            noisy = F.interpolate(noisy.unsqueeze(0), size=[ps, ps], mode='bilinear').squeeze(0)
            point = F.interpolate(point.unsqueeze(0), size=[ps, ps], mode='nearest').squeeze(0)
            normal = F.interpolate(normal.unsqueeze(0), size=[ps, ps], mode='nearest').squeeze(0)

        # Always return the same 6-tuple so callers can unpack uniformly.
        return clean, noisy, point, normal, clean_filename, noisy_filename
##################################################################################################
class DataLoaderVal(Dataset):
    """Validation dataset: full-size paired images plus geometry, no augmentation.

    ``rgb_dir`` must contain the sub-directories ``shadow_free`` (targets),
    ``origin`` (inputs), ``depth`` and ``normal``.  Samples are paired purely
    by position in each directory's sorted listing.

    Args:
        rgb_dir: Root directory holding the four sub-directories.
        target_transform: Stored on the instance; not applied by this class.
        debug: When true, the dataset is limited to at most 10 samples.
    """

    def __init__(self, rgb_dir, target_transform=None, debug=False):
        super(DataLoaderVal, self).__init__()

        self.target_transform = target_transform

        gt_dir = 'shadow_free'
        input_dir = 'origin'
        depth_dir = 'depth'
        normal_dir = 'normal'

        clean_files = sorted(os.listdir(os.path.join(rgb_dir, gt_dir)))
        noisy_files = sorted(os.listdir(os.path.join(rgb_dir, input_dir)))
        depth_files = sorted(os.listdir(os.path.join(rgb_dir, depth_dir)))
        normal_files = sorted(os.listdir(os.path.join(rgb_dir, normal_dir)))

        self.clean_filenames = [os.path.join(rgb_dir, gt_dir, x) for x in clean_files]
        self.noisy_filenames = [os.path.join(rgb_dir, input_dir, x) for x in noisy_files]
        self.depth_filenames = [os.path.join(rgb_dir, depth_dir, x) for x in depth_files]
        self.normal_filenames = [os.path.join(rgb_dir, normal_dir, x) for x in normal_files]

        # In debug mode use at most 10 samples.  The cap is clamped to the
        # number of available inputs so a small dataset cannot be indexed
        # past its end.
        total = len(self.noisy_filenames)
        self.tar_size = min(10, total) if debug else total

    def __len__(self):
        """Return the number of samples exposed by the dataset."""
        return self.tar_size

    def __getitem__(self, index):
        """Load one validation sample without augmentation or cropping.

        Returns:
            ``(clean, noisy, point, normal, clean_filename, noisy_filename)``.
            The four tensors are permuted with ``(2, 0, 1)`` (presumably
            HWC -> CHW -- confirm against the loaders in ``utils``).
        """
        tar_index = index % self.tar_size
        clean_path = self.clean_filenames[tar_index]
        noisy_path = self.noisy_filenames[tar_index]

        clean = np.float32(load_img(clean_path))
        noisy = np.float32(load_img(noisy_path))
        depth = np.float32(load_depth(self.depth_filenames[tar_index]))
        normal = np.float32(load_normal(self.normal_filenames[tar_index]))

        # NOTE(review): 60 is presumably the camera field of view expected by
        # depthToPoint -- confirm against utils.
        point = depthToPoint(60, depth)
        normal = process_normal(normal)

        # Normalise the point map by twice the mean of its channel 2
        # (presumably the depth/z coordinate).  Done on the NumPy array,
        # before conversion to a tensor.
        point = point / (2 * point[:, :, 2].mean())

        clean_filename = os.path.split(clean_path)[-1]
        noisy_filename = os.path.split(noisy_path)[-1]

        clean = torch.from_numpy(clean).permute(2, 0, 1)
        noisy = torch.from_numpy(noisy).permute(2, 0, 1)
        point = torch.from_numpy(point).permute(2, 0, 1)
        normal = torch.from_numpy(normal).permute(2, 0, 1)

        return clean, noisy, point, normal, clean_filename, noisy_filename
##################################################################################################
class DataLoaderTest(Dataset):
    """Test dataset: input images plus geometry, with no ground truth.

    ``rgb_dir`` must contain the sub-directories ``origin`` (inputs),
    ``depth`` and ``normal``; no ``shadow_free`` directory is read.  Samples
    are paired purely by position in each directory's sorted listing.

    Args:
        rgb_dir: Root directory holding the three sub-directories.
        target_transform: Stored on the instance; not applied by this class.
        debug: When true, the dataset is limited to at most 10 samples.
    """

    def __init__(self, rgb_dir, target_transform=None, debug=False):
        super(DataLoaderTest, self).__init__()

        self.target_transform = target_transform

        input_dir = 'origin'
        depth_dir = 'depth'
        normal_dir = 'normal'

        noisy_files = sorted(os.listdir(os.path.join(rgb_dir, input_dir)))
        depth_files = sorted(os.listdir(os.path.join(rgb_dir, depth_dir)))
        normal_files = sorted(os.listdir(os.path.join(rgb_dir, normal_dir)))

        self.noisy_filenames = [os.path.join(rgb_dir, input_dir, x) for x in noisy_files]
        self.depth_filenames = [os.path.join(rgb_dir, depth_dir, x) for x in depth_files]
        self.normal_filenames = [os.path.join(rgb_dir, normal_dir, x) for x in normal_files]

        # In debug mode use at most 10 samples.  The cap is clamped to the
        # number of available inputs so a small dataset cannot be indexed
        # past its end.
        total = len(self.noisy_filenames)
        self.tar_size = min(10, total) if debug else total

    def __len__(self):
        """Return the number of samples exposed by the dataset."""
        return self.tar_size

    def __getitem__(self, index):
        """Load one test sample without augmentation or cropping.

        Returns:
            ``(noisy, noisy, point, normal, noisy_filename, noisy_filename)``.
            There is no ground truth at test time, so the input image and its
            file name fill the ``clean`` slots; this keeps the tuple layout
            identical to the 6-tuple returned by ``DataLoaderVal``.  The
            tensors are permuted with ``(2, 0, 1)`` (presumably HWC -> CHW --
            confirm against the loaders in ``utils``).
        """
        tar_index = index % self.tar_size
        noisy_path = self.noisy_filenames[tar_index]

        noisy = np.float32(load_img(noisy_path))
        depth = np.float32(load_depth(self.depth_filenames[tar_index]))
        normal = np.float32(load_normal(self.normal_filenames[tar_index]))

        # NOTE(review): 60 is presumably the camera field of view expected by
        # depthToPoint -- confirm against utils.
        point = depthToPoint(60, depth)
        normal = process_normal(normal)

        # Normalise the point map by twice the mean of its channel 2
        # (presumably the depth/z coordinate).  Done on the NumPy array,
        # before conversion to a tensor.
        point = point / (2 * point[:, :, 2].mean())

        noisy_filename = os.path.split(noisy_path)[-1]

        noisy = torch.from_numpy(noisy).permute(2, 0, 1)
        point = torch.from_numpy(point).permute(2, 0, 1)
        normal = torch.from_numpy(normal).permute(2, 0, 1)

        return noisy, noisy, point, normal, noisy_filename, noisy_filename