import sys
sys.path.append('droid_slam')

import cv2
import numpy as np
from collections import OrderedDict

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from data_readers.factory import dataset_factory

from lietorch import SO3, SE3, Sim3
from geom import losses
from geom.losses import geodesic_loss, residual_loss, flow_loss
from geom.graph_utils import build_frame_graph

# network
from droid_net import DroidNet
from logger import Logger

# DDP training
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_ddp(gpu, args):
    """ Initialize the NCCL process group for this GPU worker. """
    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=args.world_size,
        rank=gpu)

    torch.manual_seed(0)
    torch.cuda.set_device(gpu)

def show_image(image):
    image = image.permute(1, 2, 0).cpu().numpy()
    cv2.imshow('image', image / 255.0)
    cv2.waitKey()

def train(gpu, args):
    """ Training loop for one GPU worker (spawned once per GPU). """

    # coordinate multiple GPUs
    setup_ddp(gpu, args)
    rng = np.random.default_rng(12345)

    N = args.n_frames
    model = DroidNet()
    model.cuda()
    model.train()

    model = DDP(model, device_ids=[gpu], find_unused_parameters=False)

    if args.ckpt is not None:
        model.load_state_dict(torch.load(args.ckpt))

    # fetch dataloader
    db = dataset_factory(['tartan'], datapath=args.datapath, n_frames=args.n_frames,
        fmin=args.fmin, fmax=args.fmax)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        db, shuffle=True, num_replicas=args.world_size, rank=gpu)

    train_loader = DataLoader(db, batch_size=args.batch, sampler=train_sampler, num_workers=2)

    # fetch optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,
        args.lr, args.steps, pct_start=0.01, cycle_momentum=False)

    logger = Logger(args.name, scheduler)

    should_keep_training = True
    total_steps = 0

    while should_keep_training:
        for i_batch, item in enumerate(train_loader):
            optimizer.zero_grad()

            images, poses, disps, intrinsics = [x.to('cuda') for x in item]

            # convert poses w2c -> c2w
            Ps = SE3(poses).inv()
            Gs = SE3.IdentityLike(Ps)

            # randomize frame graph
            if np.random.rand() < 0.5:
                graph = build_frame_graph(poses, disps, intrinsics, num=args.edges)
            else:
                graph = OrderedDict()
                for i in range(N):
                    graph[i] = [j for j in range(N) if i != j and abs(i - j) <= 2]

            # fix first two camera poses
            Gs.data[:,0] = Ps.data[:,0].clone()
            Gs.data[:,1:] = Ps.data[:,[1]].clone()

            disp0 = torch.ones_like(disps[:,:,3::8,3::8])

            # perform random restarts
            r = 0
            while r < args.restart_prob:
                r = rng.random()

                intrinsics0 = intrinsics / 8.0
                poses_est, disps_est, residuals = model(Gs, images, disp0, intrinsics0,
                    graph, num_steps=args.iters, fixedp=2)

                geo_loss, geo_metrics = losses.geodesic_loss(Ps, poses_est, graph, do_scale=False)
                res_loss, res_metrics = losses.residual_loss(residuals)
                flo_loss, flo_metrics = losses.flow_loss(Ps, disps, poses_est, disps_est, intrinsics, graph)

                loss = args.w1 * geo_loss + args.w2 * res_loss + args.w3 * flo_loss
                loss.backward()

                Gs = poses_est[-1].detach()
                disp0 = disps_est[-1][:,:,3::8,3::8].detach()

                metrics = {}
                metrics.update(geo_metrics)
                metrics.update(res_metrics)
                metrics.update(flo_metrics)

            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            scheduler.step()

            total_steps += 1

            if gpu == 0:
                logger.push(metrics)

            if total_steps % 10000 == 0 and gpu == 0:
                PATH = 'checkpoints/%s_%06d.pth' % (args.name, total_steps)
                torch.save(model.state_dict(), PATH)

            if total_steps >= args.steps:
                should_keep_training = False
                break

    dist.destroy_process_group()

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', default='bla', help='name your experiment')
    parser.add_argument('--ckpt', help='checkpoint to restore')
    parser.add_argument('--datasets', nargs='+', help='lists of datasets for training')
    parser.add_argument('--datapath', default='datasets/TartanAir', help="path to dataset directory")
    parser.add_argument('--gpus', type=int, default=4)
    parser.add_argument('--batch', type=int, default=1)
    parser.add_argument('--iters', type=int, default=15)
    parser.add_argument('--steps', type=int, default=250000)
    parser.add_argument('--lr', type=float, default=0.00025)
    parser.add_argument('--clip', type=float, default=2.5)
    parser.add_argument('--n_frames', type=int, default=7)
    parser.add_argument('--w1', type=float, default=10.0)
    parser.add_argument('--w2', type=float, default=0.01)
    parser.add_argument('--w3', type=float, default=0.05)
    parser.add_argument('--fmin', type=float, default=8.0)
    parser.add_argument('--fmax', type=float, default=96.0)
    parser.add_argument('--noise', action='store_true')
    parser.add_argument('--scale', action='store_true')
    parser.add_argument('--edges', type=int, default=24)
    parser.add_argument('--restart_prob', type=float, default=0.2)

    args = parser.parse_args()
    args.world_size = args.gpus
    print(args)

    import os
    if not os.path.isdir('checkpoints'):
        os.mkdir('checkpoints')
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12356'
    mp.spawn(train, nprocs=args.gpus, args=(args,))
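
# Example launch command (a sketch only: the filename "train.py" and the dataset layout
# at --datapath are assumptions, not stated in this file; every flag below is defined in
# the argument parser above, and the values shown are simply the defaults made explicit):
#
#   python train.py --name droid --datapath datasets/TartanAir --gpus 4 --batch 1 \
#       --lr 0.00025 --steps 250000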