import argparse from net.dornet import Net from net.CR import * from data.rgbdd_dataloader import * from data.nyu_dataloader import * from utils import calc_rmse, rgbdd_calc_rmse from torch.utils.data import Dataset from torchvision import transforms, utils import torch import torch.optim as optim import torch.nn as nn from tqdm import tqdm import logging from datetime import datetime import os import numpy as np parser = argparse.ArgumentParser() parser.add_argument('--scale', type=int, default=4, help='scale factor') parser.add_argument('--lr', default='0.0001', type=float, help='learning rate') parser.add_argument('--result', default='experiment', help='learning rate') parser.add_argument('--tiny_model', action='store_true', help='tiny model') parser.add_argument('--epoch', default=300, type=int, help='max epoch') parser.add_argument("--decay_iterations", type=list, default=[1.2e5, 2e5, 3.6e5], help="steps to start lr decay") parser.add_argument("--gamma", type=float, default=0.2, help="decay rate of learning rate") parser.add_argument("--root_dir", type=str, default='./dataset/RGB-D-D', help="root dir of dataset") parser.add_argument("--batch_size", type=int, default=3, help="batch_size of training dataloader") parser.add_argument("--blur_sigma", type=int, default=3.6, help="blur_sigma") parser.add_argument('--isNoisy', action='store_true', help='Noisy') opt = parser.parse_args() print(opt) s = datetime.now().strftime('%Y%m%d%H%M%S') dataset_name = opt.root_dir.split('/')[-1] result_root = '%s/%s-lr_%s-s_%s-%s-b_%s' % (opt.result, s, opt.lr, opt.scale, dataset_name, opt.batch_size) if not os.path.exists(result_root): os.mkdir(result_root) logging.basicConfig(filename='%s/train.log' % result_root, format='%(asctime)s %(message)s', level=logging.INFO) logging.info(opt) net = Net(tiny_model=opt.tiny_model).cuda() print("**********************Parameters***********************") print(sum(p.numel() for p in net.parameters() if p.requires_grad)) print("**********************Parameters***********************") net.train() optimizer = optim.Adam(net.parameters(), lr=opt.lr) scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=opt.decay_iterations, gamma=opt.gamma) CL = ContrastLoss(ablation=False) l1 = nn.L1Loss().cuda() data_transform = transforms.Compose([transforms.ToTensor()]) if dataset_name == 'RGB-D-D': train_dataset = RGBDD_Dataset(root_dir=opt.root_dir, scale=opt.scale, downsample='real', train=True, transform=data_transform, isNoisy=opt.isNoisy, blur_sigma=opt.blur_sigma) test_dataset = RGBDD_Dataset(root_dir=opt.root_dir, scale=opt.scale, downsample='real', train=False, transform=data_transform, isNoisy=opt.isNoisy, blur_sigma=opt.blur_sigma) elif dataset_name == 'NYU-v2': test_minmax = np.load('%s/test_minmax.npy' % opt.root_dir) train_dataset = NYU_v2_datset(root_dir=opt.root_dir, scale=opt.scale, transform=data_transform, train=True) test_dataset = NYU_v2_datset(root_dir=opt.root_dir, scale=opt.scale, transform=data_transform, train=False) train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=8) test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=8) max_epoch = opt.epoch num_train = len(train_dataloader) best_rmse = 100.0 best_epoch = 0 for epoch in range(max_epoch): # --------- # Training # --------- net.train() running_loss = 0.0 t = tqdm(iter(train_dataloader), leave=True, total=len(train_dataloader)) for idx, data in enumerate(t): batches_done = num_train * epoch + idx optimizer.zero_grad() guidance, lr, gt = data['guidance'].cuda(), data['lr'].cuda(), data['gt'].cuda() restored, d_lr_, aux_loss = net(x_query=lr, rgb=guidance) rec_loss = l1(restored, gt) da_loss = l1(d_lr_, lr) cl_loss = CL(d_lr_,lr,restored) loss = rec_loss + 0.1 * da_loss + 0.1 * cl_loss + aux_loss loss.backward() optimizer.step() scheduler.step() running_loss += loss.data.item() t.set_description( '[train epoch:%d] loss: Rec_loss:%.8f DA_loss:%.8f CL_loss:%.8f' % (epoch + 1, rec_loss.item(), da_loss.item(), cl_loss.item())) t.refresh() logging.info('epoch:%d iteration:%d running_loss:%.10f' % (epoch + 1, batches_done + 1, running_loss / num_train)) # ----------- # Validating # ----------- with torch.no_grad(): net.eval() if dataset_name == 'RGB-D-D': rmse = np.zeros(405) elif dataset_name == 'NYU-v2': rmse = np.zeros(449) t = tqdm(iter(test_dataloader), leave=True, total=len(test_dataloader)) for idx, data in enumerate(t): if dataset_name == 'RGB-D-D': guidance, lr, gt, max, min = data['guidance'].cuda(), data['lr'].cuda(), data['gt'].cuda(), data[ 'max'].cuda(), data['min'].cuda() out = net(x_query=lr, rgb=guidance) minmax = [max, min] rmse[idx] = rgbdd_calc_rmse(gt[0, 0], out[0, 0], minmax) t.set_description('[validate] rmse: %f' % rmse[:idx + 1].mean()) t.refresh() elif dataset_name == 'NYU-v2': guidance, lr, gt = data['guidance'].cuda(), data['lr'].cuda(), data['gt'].cuda() out = net(x_query=lr, rgb=guidance) minmax = test_minmax[:, idx] minmax = torch.from_numpy(minmax).cuda() rmse[idx] = calc_rmse(gt[0, 0], out[0, 0], minmax) t.set_description('[validate] rmse: %f' % rmse[:idx + 1].mean()) t.refresh() r_mean = rmse.mean() if r_mean < best_rmse: best_rmse = r_mean best_epoch = epoch torch.save(net.state_dict(), os.path.join(result_root, "RMSE%f_8%d.pth" % (best_rmse, best_epoch + 1))) logging.info( '---------------------------------------------------------------------------------------------------------------------------') logging.info('epoch:%d lr:%f-------mean_rmse:%f (BEST: %f @epoch%d)' % ( epoch + 1, scheduler.get_last_lr()[0], r_mean, best_rmse, best_epoch + 1)) logging.info( '---------------------------------------------------------------------------------------------------------------------------')