# DROID-SLAM training script (vendored third-party code).
# Provenance: Hugging Face upload "add thirdparty", commit b7eedf7.
import sys
sys.path.append('droid_slam')
import cv2
import numpy as np
from collections import OrderedDict
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from data_readers.factory import dataset_factory
from lietorch import SO3, SE3, Sim3
from geom import losses
from geom.losses import geodesic_loss, residual_loss, flow_loss
from geom.graph_utils import build_frame_graph
# network
from droid_net import DroidNet
from logger import Logger
# DDP training
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def setup_ddp(gpu, args):
    """Join the NCCL process group as rank `gpu` and bind this process to that GPU.

    Rendezvous uses the env:// method, so MASTER_ADDR / MASTER_PORT must be
    set in the environment before this is called (done in __main__).
    """
    ddp_config = dict(
        backend='nccl',
        init_method='env://',
        world_size=args.world_size,
        rank=gpu,
    )
    dist.init_process_group(**ddp_config)

    # Fixed seed so every rank builds identical initial weights.
    torch.manual_seed(0)
    # One process per device: all subsequent .cuda() calls target this GPU.
    torch.cuda.set_device(gpu)
def show_image(image):
    """Debug helper: display a CHW image tensor in an OpenCV window.

    Blocks until a key is pressed. Pixel values are assumed to be in
    [0, 255] and are rescaled to [0, 1] for imshow.
    """
    hwc = image.permute(1, 2, 0).cpu().numpy()
    cv2.imshow('image', hwc / 255.0)
    cv2.waitKey()
def train(gpu, args):
    """Per-rank DDP training worker for DroidNet on TartanAir.

    Spawned once per GPU by mp.spawn in __main__. Builds the model, a
    distributed dataloader, and runs the pose/flow training loop until
    args.steps optimizer steps have been taken.
    """
    # coordinate multiple GPUs
    setup_ddp(gpu, args)
    # Dedicated RNG for the restart schedule (decoupled from np.random below).
    rng = np.random.default_rng(12345)

    N = args.n_frames
    model = DroidNet()
    model.cuda()
    model.train()

    model = DDP(model, device_ids=[gpu], find_unused_parameters=False)

    # NOTE(review): checkpoint is loaded AFTER DDP wrapping, so the saved
    # state dict is expected to carry the 'module.' key prefix — consistent
    # with the torch.save(model.state_dict()) call below.
    if args.ckpt is not None:
        model.load_state_dict(torch.load(args.ckpt))

    # fetch dataloader
    db = dataset_factory(['tartan'], datapath=args.datapath, n_frames=args.n_frames, fmin=args.fmin, fmax=args.fmax)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        db, shuffle=True, num_replicas=args.world_size, rank=gpu)

    train_loader = DataLoader(db, batch_size=args.batch, sampler=train_sampler, num_workers=2)

    # fetch optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,
        args.lr, args.steps, pct_start=0.01, cycle_momentum=False)

    logger = Logger(args.name, scheduler)
    should_keep_training = True
    total_steps = 0

    while should_keep_training:
        for i_batch, item in enumerate(train_loader):
            optimizer.zero_grad()

            images, poses, disps, intrinsics = [x.to('cuda') for x in item]

            # convert poses w2c -> c2w
            Ps = SE3(poses).inv()
            Gs = SE3.IdentityLike(Ps)

            # randomize frame graph: half the time a learned/heuristic graph,
            # otherwise a fixed +/-2 temporal-neighbor graph over the N frames.
            if np.random.rand() < 0.5:
                graph = build_frame_graph(poses, disps, intrinsics, num=args.edges)
            else:
                graph = OrderedDict()
                for i in range(N):
                    graph[i] = [j for j in range(N) if i!=j and abs(i-j) <= 2]

            # fix first to camera poses: frame 0 gets its ground-truth pose,
            # and all later frames are initialized to frame 1's pose.
            Gs.data[:,0] = Ps.data[:,0].clone()
            Gs.data[:,1:] = Ps.data[:,[1]].clone()
            # Initial disparity guess at 1/8 resolution (model's output scale).
            disp0 = torch.ones_like(disps[:,:,3::8,3::8])

            # perform random restarts: always runs at least once; with
            # probability ~args.restart_prob the estimate is re-fed into the
            # model for another round. Gradients from every round accumulate
            # into one optimizer step (zero_grad is outside this loop).
            r = 0
            while r < args.restart_prob:
                r = rng.random()

                # intrinsics scaled to the 1/8-resolution disparity grid
                intrinsics0 = intrinsics / 8.0
                poses_est, disps_est, residuals = model(Gs, images, disp0, intrinsics0,
                    graph, num_steps=args.iters, fixedp=2)

                geo_loss, geo_metrics = losses.geodesic_loss(Ps, poses_est, graph, do_scale=False)
                res_loss, res_metrics = losses.residual_loss(residuals)
                flo_loss, flo_metrics = losses.flow_loss(Ps, disps, poses_est, disps_est, intrinsics, graph)

                loss = args.w1 * geo_loss + args.w2 * res_loss + args.w3 * flo_loss
                loss.backward()

                # Detach the final estimates so a restart continues from them
                # without backpropagating through the previous round.
                Gs = poses_est[-1].detach()
                disp0 = disps_est[-1][:,:,3::8,3::8].detach()

            # Metrics from the last restart round only.
            metrics = {}
            metrics.update(geo_metrics)
            metrics.update(res_metrics)
            metrics.update(flo_metrics)

            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            scheduler.step()

            total_steps += 1

            # Only rank 0 logs and checkpoints.
            if gpu == 0:
                logger.push(metrics)

            if total_steps % 10000 == 0 and gpu == 0:
                PATH = 'checkpoints/%s_%06d.pth' % (args.name, total_steps)
                torch.save(model.state_dict(), PATH)

            if total_steps >= args.steps:
                should_keep_training = False
                break

    dist.destroy_process_group()
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', default='bla', help='name your experiment')
    parser.add_argument('--ckpt', help='checkpoint to restore')
    parser.add_argument('--datasets', nargs='+', help='lists of datasets for training')
    parser.add_argument('--datapath', default='datasets/TartanAir', help="path to dataset directory")

    parser.add_argument('--gpus', type=int, default=4)
    parser.add_argument('--batch', type=int, default=1)
    parser.add_argument('--iters', type=int, default=15)
    parser.add_argument('--steps', type=int, default=250000)
    parser.add_argument('--lr', type=float, default=0.00025)
    parser.add_argument('--clip', type=float, default=2.5)
    parser.add_argument('--n_frames', type=int, default=7)

    # loss weights: geodesic (pose), residual, optical flow
    parser.add_argument('--w1', type=float, default=10.0)
    parser.add_argument('--w2', type=float, default=0.01)
    parser.add_argument('--w3', type=float, default=0.05)

    parser.add_argument('--fmin', type=float, default=8.0)
    parser.add_argument('--fmax', type=float, default=96.0)
    parser.add_argument('--noise', action='store_true')
    parser.add_argument('--scale', action='store_true')
    parser.add_argument('--edges', type=int, default=24)
    parser.add_argument('--restart_prob', type=float, default=0.2)

    # Fix: args were parsed (and world_size assigned) twice in the original;
    # parse exactly once.
    args = parser.parse_args()
    args.world_size = args.gpus  # one DDP process per GPU
    print(args)

    import os
    # Fix: isdir-then-mkdir is racy and fails confusingly if a non-directory
    # named 'checkpoints' exists; makedirs(exist_ok=True) is idempotent.
    os.makedirs('checkpoints', exist_ok=True)

    # Rendezvous address for the env:// init_method used in setup_ddp.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12356'
    mp.spawn(train, nprocs=args.gpus, args=(args,))