import argparse
import os
import random

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from skimage.metrics import peak_signal_noise_ratio as psnr_loss
from skimage.metrics import structural_similarity as ssim_loss

import utils
from utils.loader import get_test_data
parser = argparse.ArgumentParser(description='RGB denoising evaluation on the validation set of SIDD')
parser.add_argument('--input_dir', default='test_dir', type=str, help='Directory of validation images')
parser.add_argument('--result_dir', default='./output_dir', type=str, help='Directory for results')
parser.add_argument('--weights', default='best_WSRD.pth', type=str, help='Path to weights')
parser.add_argument('--arch', type=str, default='DenseSR', help='architecture')
parser.add_argument('--batch_size', default=1, type=int, help='Batch size for dataloader')
parser.add_argument('--save_images', action='store_true', default=False, help='Save denoised images in result directory')
parser.add_argument('--cal_metrics', action='store_true', default=False, help='Measure denoised images with GT')
parser.add_argument('--embed_dim', type=int, default=32, help='dimension of embedding features')
parser.add_argument('--win_size', type=int, default=16, help='window size of self-attention')
parser.add_argument('--token_projection', type=str, default='linear', help='linear/conv token projection')
parser.add_argument('--token_mlp', type=str, default='leff', help='ffn/leff token mlp')

parser.add_argument('--train_ps', type=int, default=256, help='patch size of training sample')
parser.add_argument('--local-rank', type=int)

args = parser.parse_args()

local_rank = args.local_rank
torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl')
device = torch.device("cuda", local_rank)
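# Note: this script expects to be started by a distributed launcher that passes
# --local-rank to every process (e.g. `python -m torch.distributed.launch
# --nproc_per_node=<num_gpus> test.py ...` on recent PyTorch versions). The exact
# launcher flags depend on your PyTorch version, so treat this as a sketch.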
class SlidingWindowInference:
    """Tiled inference: pad the input, run the model on overlapping windows,
    and blend the window outputs back into a full-resolution prediction."""

    def __init__(self, window_size=512, overlap=64, img_multiple_of=64):
        self.window_size = window_size
        self.overlap = overlap
        self.img_multiple_of = img_multiple_of

    def _pad_input(self, x, h_pad, w_pad):
        return F.pad(x, (0, w_pad, 0, h_pad), 'reflect')

    def __call__(self, model, input_, point, normal, dino_net, device):
        original_height, original_width = input_.shape[2], input_.shape[3]

        # Pad up to a multiple of img_multiple_of, and to at least one full window.
        H = max(self.window_size,
                ((original_height + self.img_multiple_of - 1) // self.img_multiple_of) * self.img_multiple_of)
        W = max(self.window_size,
                ((original_width + self.img_multiple_of - 1) // self.img_multiple_of) * self.img_multiple_of)

        padh = H - original_height
        padw = W - original_width

        input_pad = self._pad_input(input_, padh, padw)
        point_pad = self._pad_input(point, padh, padw)
        normal_pad = self._pad_input(normal, padh, padw)

        # Small images fit into a single window: one forward pass, then crop.
        if original_height <= self.window_size and original_width <= self.window_size:
            DINO_patch_size = 14
            h_size = H * DINO_patch_size // 8
            w_size = W * DINO_patch_size // 8

            UpSample_window = torch.nn.UpsamplingBilinear2d(size=(h_size, w_size))

            with torch.no_grad():
                input_DINO = UpSample_window(input_pad)
                dino_features = dino_net.module.get_intermediate_layers(input_DINO, 4, True)

            with torch.amp.autocast(device_type='cuda'):
                restored = model(input_pad, dino_features, point_pad, normal_pad)

            output = restored[:, :, :original_height, :original_width]
            return output

        # Large images: slide a window over the padded input with the given overlap.
        stride = self.window_size - self.overlap
        h_steps = (H - self.window_size + stride - 1) // stride + 1
        w_steps = (W - self.window_size + stride - 1) // stride + 1

        # Accumulate weighted window outputs and per-pixel weights, then normalize.
        output = torch.zeros_like(input_pad)
        count = torch.zeros_like(input_pad)

        for h_idx in range(h_steps):
            for w_idx in range(w_steps):
                h_start = min(h_idx * stride, H - self.window_size)
                w_start = min(w_idx * stride, W - self.window_size)
                h_end = h_start + self.window_size
                w_end = w_start + self.window_size

                input_window = input_pad[:, :, h_start:h_end, w_start:w_end]
                point_window = point_pad[:, :, h_start:h_end, w_start:w_end]
                normal_window = normal_pad[:, :, h_start:h_end, w_start:w_end]

                # Upscale the window by 14/8 so that each 14-pixel DINOv2 patch
                # corresponds to an 8-pixel region of the original window.
                DINO_patch_size = 14
                h_size = self.window_size * DINO_patch_size // 8
                w_size = self.window_size * DINO_patch_size // 8

                UpSample_window = torch.nn.UpsamplingBilinear2d(size=(h_size, w_size))

                with torch.no_grad():
                    input_DINO = UpSample_window(input_window)
                    dino_features = dino_net.module.get_intermediate_layers(input_DINO, 4, True)

                with torch.amp.autocast(device_type='cuda'):
                    restored = model(input_window, dino_features, point_window, normal_window)

                # Linear ramp towards the window borders so overlapping tiles blend
                # smoothly. The ramp stays strictly positive; otherwise border pixels
                # covered by a single tile would be divided away to zero below.
                weight = torch.ones_like(restored)
                if self.overlap > 0:
                    for i in range(self.overlap):
                        ratio = (i + 1) / (self.overlap + 1)
                        weight[:, :, i, :] *= ratio
                        weight[:, :, -(i + 1), :] *= ratio
                        weight[:, :, :, i] *= ratio
                        weight[:, :, :, -(i + 1)] *= ratio

                output[:, :, h_start:h_end, w_start:w_end] += restored * weight
                count[:, :, h_start:h_end, w_start:w_end] += weight

        # Normalize: every pixel becomes the weighted average of the tiles covering it.
        output = output / (count + 1e-6)

        output = output[:, :, :original_height, :original_width]
        return output


utils.mkdir(args.result_dir)

# Fix all random seeds for reproducible evaluation.
random.seed(1234)
np.random.seed(1234)
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
torch.cuda.manual_seed_all(1234)

def worker_init_fn(worker_id):
    random.seed(1234 + worker_id)

g = torch.Generator()
g.manual_seed(1234)

torch.backends.cudnn.benchmark = True

model_restoration = utils.get_arch(args)
model_restoration.to(device)
model_restoration.eval()

# DINOv2 backbone used as an auxiliary feature extractor.
DINO_Net = torch.hub.load('./dinov2', 'dinov2_vitl14', source='local')
DINO_Net.to(device)
DINO_Net.eval()

utils.load_checkpoint(model_restoration, args.weights)
print("===>Testing using weights: ", args.weights)

model_restoration = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_restoration).to(device)
model_restoration = DDP(model_restoration, device_ids=[local_rank], output_device=local_rank)
DINO_Net = DDP(DINO_Net, device_ids=[local_rank], output_device=local_rank)
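# The torch.hub call above uses source='local', so it assumes the DINOv2 repository
# (containing hubconf.py) has already been cloned into ./dinov2.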
img_multiple_of = 8 * args.win_size
DINO_patch_size = 14

def UpSample(img):
    upsample = nn.UpsamplingBilinear2d(
        size=(int(img.shape[2] * (DINO_patch_size / 8)),
              int(img.shape[3] * (DINO_patch_size / 8))))
    return upsample(img)

img_options_train = {'patch_size': args.train_ps}
test_dataset = get_test_data(args.input_dir, False)
test_sampler = DistributedSampler(test_dataset, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=1, num_workers=0, sampler=test_sampler,
                         drop_last=False, worker_init_fn=worker_init_fn, generator=g)
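# DistributedSampler shards the test set across ranks, so each process only sees
# (and saves/evaluates) its own subset; by default it also pads the last shard by
# repeating samples so that every rank receives the same number of items.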
with torch.no_grad():
    psnr_val_rgb_list = []
    psnr_val_mask_list = []
    ssim_val_rgb_list = []
    rmse_val_rgb_list = []
    for ii, data_test in enumerate(tqdm(test_loader), 0):
        rgb_noisy = data_test[1].to(device)
        point = data_test[2].to(device)
        normal = data_test[3].to(device)
        filenames = data_test[4]

        sliding_window = SlidingWindowInference(
            window_size=512,
            overlap=64,
            img_multiple_of=8 * args.win_size
        )

        with torch.amp.autocast(device_type='cuda'):
            rgb_restored = sliding_window(
                model=model_restoration,
                input_=rgb_noisy,
                point=point,
                normal=normal,
                dino_net=DINO_Net,
                device=device
            )

        # Clamp to [0, 1] and convert to an HWC numpy array.
        rgb_restored = torch.clamp(rgb_restored, 0, 1).cpu().numpy().squeeze().transpose((1, 2, 0))

        if args.save_images:
            utils.save_img(rgb_restored * 255.0, os.path.join(args.result_dir, filenames[0]))
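        # Hedged sketch for --cal_metrics: the layout of data_test is defined by
        # utils.loader.get_test_data, which is not shown here; this assumes data_test[0]
        # holds the ground-truth image in [0, 1] (adjust the index if the loader differs),
        # and that scikit-image is recent enough to support channel_axis.
        if args.cal_metrics:
            rgb_gt = data_test[0].cpu().numpy().squeeze().transpose((1, 2, 0))
            psnr_val_rgb_list.append(psnr_loss(rgb_gt, rgb_restored, data_range=1.0))
            ssim_val_rgb_list.append(ssim_loss(rgb_gt, rgb_restored, channel_axis=-1, data_range=1.0))
            rmse_val_rgb_list.append(np.sqrt(np.mean((rgb_gt - rgb_restored) ** 2)))

    if args.cal_metrics and psnr_val_rgb_list:
        # Per-rank averages only; with DistributedSampler each process scores its own shard.
        print("PSNR: %.4f  SSIM: %.4f  RMSE: %.4f" % (
            np.mean(psnr_val_rgb_list), np.mean(ssim_val_rgb_list), np.mean(rmse_val_rgb_list)))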