"""Dataset loaders for RGB shadow removal with depth and normal guidance.

The train and validation roots are expected to contain four parallel
sub-folders: 'shadow_free' (ground truth), 'origin' (shadowed input),
'depth', and 'normal'; the test root omits 'shadow_free'. File lists are
sorted, so the i-th entry of each folder is assumed to belong to the
same scene.
"""

import os
import random

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

from utils import (Augment_RGB_torch, depthToPoint, load_depth, load_img,
                   load_normal, process_normal)

augment = Augment_RGB_torch()
# Collect every public method of the augmenter as a candidate transform.
transforms_aug = [m for m in dir(augment)
                  if callable(getattr(augment, m)) and not m.startswith('_')]

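
# For reference, a hedged sketch of what a depth-to-point back-projection
# might look like under a pinhole model with a vertical field of view given
# in degrees. This is an illustrative stand-in, not the actual
# utils.depthToPoint, whose conventions (FoV axis, principal point,
# handedness) may differ.
def _depth_to_point_sketch(fov_deg, depth):
    """Map an (H, W) depth map to an (H, W, 3) camera-space point map."""
    H, W = depth.shape
    focal = 0.5 * H / np.tan(0.5 * np.deg2rad(fov_deg))
    ys, xs = np.meshgrid(np.arange(H), np.arange(W), indexing='ij')
    x = (xs - 0.5 * W) / focal * depth
    y = (ys - 0.5 * H) / focal * depth
    return np.stack([x, y, depth], axis=-1).astype(np.float32)
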
class DataLoaderTrain(Dataset):
    def __init__(self, rgb_dir, img_options=None, target_transform=None, debug=False):
        super().__init__()

        self.target_transform = target_transform

        gt_dir = 'shadow_free'
        input_dir = 'origin'
        depth_dir = 'depth'
        normal_dir = 'normal'

        clean_files = sorted(os.listdir(os.path.join(rgb_dir, gt_dir)))
        noisy_files = sorted(os.listdir(os.path.join(rgb_dir, input_dir)))
        depth_files = sorted(os.listdir(os.path.join(rgb_dir, depth_dir)))
        normal_files = sorted(os.listdir(os.path.join(rgb_dir, normal_dir)))

        self.clean_filenames = [os.path.join(rgb_dir, gt_dir, x) for x in clean_files]
        self.noisy_filenames = [os.path.join(rgb_dir, input_dir, x) for x in noisy_files]
        self.depth_filenames = [os.path.join(rgb_dir, depth_dir, x) for x in depth_files]
        self.normal_filenames = [os.path.join(rgb_dir, normal_dir, x) for x in normal_files]
        self.img_options = img_options

        # In debug mode, cap the epoch at a small fixed number of samples.
        self.tar_size = 100 if debug else len(self.noisy_filenames)

    def __len__(self):
        return self.tar_size

    def __getitem__(self, index):
        tar_index = index % self.tar_size

        clean = np.float32(load_img(self.clean_filenames[tar_index]))
        noisy = np.float32(load_img(self.noisy_filenames[tar_index]))
        depth = np.float32(load_depth(self.depth_filenames[tar_index]))
        normal = np.float32(load_normal(self.normal_filenames[tar_index]))

        # Back-project the depth map to a per-pixel 3D point map
        # (60-degree field of view) and canonicalize the normals.
        point = depthToPoint(60, depth)
        normal = process_normal(normal)

        clean = torch.from_numpy(clean)
        noisy = torch.from_numpy(noisy)
        point = torch.from_numpy(point)
        normal = torch.from_numpy(normal)

        # Normalize the point map by twice the mean depth (z channel)
        # so scenes at different scales are comparable.
        point = point / (2 * point[:, :, 2].mean())

        # HWC -> CHW
        clean = clean.permute(2, 0, 1)
        noisy = noisy.permute(2, 0, 1)
        point = point.permute(2, 0, 1)
        normal = normal.permute(2, 0, 1)

        clean_filename = os.path.split(self.clean_filenames[tar_index])[-1]
        noisy_filename = os.path.split(self.noisy_filenames[tar_index])[-1]

        # Apply the same randomly chosen augmentation to every modality so
        # they stay geometrically aligned. Only the first three transforms
        # in the (alphabetically sorted) list are sampled, preserving the
        # original behaviour.
        augment.rotate = random.randint(-20, 20)
        apply_trans = transforms_aug[random.randint(0, 2)]
        clean = getattr(augment, apply_trans)(clean)
        noisy = getattr(augment, apply_trans)(noisy)
        point = getattr(augment, apply_trans)(point)
        normal = getattr(augment, apply_trans)(normal)

        # Random square crop, identical across all modalities. Assumes every
        # image is at least scaled_ps on each side; np.random.randint would
        # raise otherwise.
        ps = self.img_options['patch_size']
        scale = 1

        H, W = noisy.shape[1], noisy.shape[2]
        scaled_ps = int(scale * ps)
        if H != scaled_ps or W != scaled_ps:
            r = np.random.randint(0, H - scaled_ps + 1)
            c = np.random.randint(0, W - scaled_ps + 1)
            clean = clean[:, r:r + scaled_ps, c:c + scaled_ps]
            noisy = noisy[:, r:r + scaled_ps, c:c + scaled_ps]
            point = point[:, r:r + scaled_ps, c:c + scaled_ps]
            normal = normal[:, r:r + scaled_ps, c:c + scaled_ps]

        if scale != 1:
            # Resize back to the target patch size: bilinear for images,
            # nearest for geometry so points/normals are not blended.
            clean = F.interpolate(clean.unsqueeze(0), size=[ps, ps], mode='bilinear').squeeze(0)
            noisy = F.interpolate(noisy.unsqueeze(0), size=[ps, ps], mode='bilinear').squeeze(0)
            point = F.interpolate(point.unsqueeze(0), size=[ps, ps], mode='nearest').squeeze(0)
            normal = F.interpolate(normal.unsqueeze(0), size=[ps, ps], mode='nearest').squeeze(0)

        return clean, noisy, point, normal, clean_filename, noisy_filename


class DataLoaderVal(Dataset):
    def __init__(self, rgb_dir, target_transform=None, debug=False):
        super().__init__()

        self.target_transform = target_transform

        gt_dir = 'shadow_free'
        input_dir = 'origin'
        depth_dir = 'depth'
        normal_dir = 'normal'

        clean_files = sorted(os.listdir(os.path.join(rgb_dir, gt_dir)))
        noisy_files = sorted(os.listdir(os.path.join(rgb_dir, input_dir)))
        depth_files = sorted(os.listdir(os.path.join(rgb_dir, depth_dir)))
        normal_files = sorted(os.listdir(os.path.join(rgb_dir, normal_dir)))

        self.clean_filenames = [os.path.join(rgb_dir, gt_dir, x) for x in clean_files]
        self.noisy_filenames = [os.path.join(rgb_dir, input_dir, x) for x in noisy_files]
        self.depth_filenames = [os.path.join(rgb_dir, depth_dir, x) for x in depth_files]
        self.normal_filenames = [os.path.join(rgb_dir, normal_dir, x) for x in normal_files]

        self.tar_size = 10 if debug else len(self.noisy_filenames)

    def __len__(self):
        return self.tar_size

    def __getitem__(self, index):
        tar_index = index % self.tar_size

        clean = np.float32(load_img(self.clean_filenames[tar_index]))
        noisy = np.float32(load_img(self.noisy_filenames[tar_index]))
        depth = np.float32(load_depth(self.depth_filenames[tar_index]))
        normal = np.float32(load_normal(self.normal_filenames[tar_index]))

        # Same preprocessing as training, but without augmentation or
        # cropping: back-project depth (60-degree FoV) and normalize the
        # point map by twice the mean depth.
        point = depthToPoint(60, depth)
        normal = process_normal(normal)
        point = point / (2 * point[:, :, 2].mean())

        clean_filename = os.path.split(self.clean_filenames[tar_index])[-1]
        noisy_filename = os.path.split(self.noisy_filenames[tar_index])[-1]

        clean = torch.from_numpy(clean)
        noisy = torch.from_numpy(noisy)
        point = torch.from_numpy(point)
        normal = torch.from_numpy(normal)

        # HWC -> CHW
        clean = clean.permute(2, 0, 1)
        noisy = noisy.permute(2, 0, 1)
        point = point.permute(2, 0, 1)
        normal = normal.permute(2, 0, 1)

        return clean, noisy, point, normal, clean_filename, noisy_filename


class DataLoaderTest(Dataset):
    def __init__(self, rgb_dir, target_transform=None, debug=False):
        super().__init__()

        self.target_transform = target_transform

        # No ground truth at test time; only the shadowed input and geometry.
        input_dir = 'origin'
        depth_dir = 'depth'
        normal_dir = 'normal'

        noisy_files = sorted(os.listdir(os.path.join(rgb_dir, input_dir)))
        depth_files = sorted(os.listdir(os.path.join(rgb_dir, depth_dir)))
        normal_files = sorted(os.listdir(os.path.join(rgb_dir, normal_dir)))

        self.noisy_filenames = [os.path.join(rgb_dir, input_dir, x) for x in noisy_files]
        self.depth_filenames = [os.path.join(rgb_dir, depth_dir, x) for x in depth_files]
        self.normal_filenames = [os.path.join(rgb_dir, normal_dir, x) for x in normal_files]

        self.tar_size = 10 if debug else len(self.noisy_filenames)

    def __len__(self):
        return self.tar_size

    def __getitem__(self, index):
        tar_index = index % self.tar_size

        noisy = np.float32(load_img(self.noisy_filenames[tar_index]))
        depth = np.float32(load_depth(self.depth_filenames[tar_index]))
        normal = np.float32(load_normal(self.normal_filenames[tar_index]))

        point = depthToPoint(60, depth)
        normal = process_normal(normal)
        point = point / (2 * point[:, :, 2].mean())

        noisy_filename = os.path.split(self.noisy_filenames[tar_index])[-1]

        noisy = torch.from_numpy(noisy)
        point = torch.from_numpy(point)
        normal = torch.from_numpy(normal)

        # HWC -> CHW
        noisy = noisy.permute(2, 0, 1)
        point = point.permute(2, 0, 1)
        normal = normal.permute(2, 0, 1)

        # The input is duplicated in the ground-truth slot so the return
        # signature matches the train/val loaders.
        return noisy, noisy, point, normal, noisy_filename, noisy_filename
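

# Minimal usage sketch. The dataset root './data/train' and the patch size
# are hypothetical: substitute a real root containing the sub-folders
# documented above, with images at least 'patch_size' on each side.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    train_set = DataLoaderTrain('./data/train', img_options={'patch_size': 256})
    loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=4)
    clean, noisy, point, normal, clean_name, noisy_name = next(iter(loader))
    # Expect [4, 3, 256, 256] for each tensor under these assumptions.
    print(clean.shape, noisy.shape, point.shape, normal.shape)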