# Spaces:
# Running
# Running
import argparse
import os
import sys
import warnings
from glob import glob

# Put the bundled DROID-SLAM sources first on the import path; this must run
# before `from droid import Droid` below.
sys.path.insert(0, 'thirdparty/DROID-SLAM/droid_slam')
sys.path.insert(0, 'thirdparty/DROID-SLAM')

import cv2
import evo
import evo.main_ape as main_ape
import numpy as np
import torch
from evo.core import sync
from evo.core.metrics import PoseRelation
from evo.core.trajectory import PoseTrajectory3D
from evo.tools import file_interface
from PIL import Image
from pycocotools import mask as masktool
from torch.multiprocessing import Process
from torchvision.transforms import Resize
from tqdm import tqdm

from droid import Droid
# Default DROID-SLAM settings. The parser is only ever evaluated as
# ``parse_args([])`` (never on the real command line), so it acts as a
# container of DROID defaults; run_slam / run_droid_slam override fields on the
# module-level ``args`` and test_slam builds its own copy from ``parser``.
parser = argparse.ArgumentParser()
parser.add_argument("--imagedir", type=str, help="path to image directory")
parser.add_argument("--calib", type=str, help="path to calibration file")
parser.add_argument("--t0", default=0, type=int, help="starting frame")
parser.add_argument("--stride", default=1, type=int, help="frame stride")
parser.add_argument("--weights", default="weights/external/droid.pth")
parser.add_argument("--buffer", type=int, default=512)
parser.add_argument("--image_size", default=[240, 320])
parser.add_argument("--disable_vis", action="store_true")
parser.add_argument("--beta", type=float, default=0.3, help="weight for translation / rotation components of flow")
parser.add_argument("--filter_thresh", type=float, default=2.4, help="how much motion before considering new keyframe")
parser.add_argument("--warmup", type=int, default=8, help="number of warmup frames")
parser.add_argument("--keyframe_thresh", type=float, default=4.0, help="threshold to create a new keyframe")
parser.add_argument("--frontend_thresh", type=float, default=16.0, help="add edges between frames whithin this distance")
parser.add_argument("--frontend_window", type=int, default=25, help="frontend optimization window")
parser.add_argument("--frontend_radius", type=int, default=2, help="force edges between frames within radius")
parser.add_argument("--frontend_nms", type=int, default=1, help="non-maximal supression of edges")
parser.add_argument("--backend_thresh", type=float, default=22.0)
parser.add_argument("--backend_radius", type=int, default=2)
parser.add_argument("--backend_nms", type=int, default=3)
parser.add_argument("--upsample", action="store_true")
parser.add_argument("--reconstruction_path", help="path to saved reconstruction")

# Module-wide settings object; note that the SLAM entry points mutate it.
args = parser.parse_args([])
args.stereo = False
args.upsample = True
args.disable_vis = True

# PyTorch documents that the CUDA runtime does not support the 'fork' start
# method, which is presumably why 'spawn' is selected here. set_start_method()
# raises RuntimeError if a start method has already been chosen (e.g. by the
# host application or on a module reload), so only set it when none is set.
# allow_none=True avoids get_start_method() itself fixing the default context.
_start_method = torch.multiprocessing.get_start_method(allow_none=True)
if _start_method is None:
    torch.multiprocessing.set_start_method('spawn')
elif _start_method != 'spawn':
    warnings.warn(
        f"multiprocessing start method is already {_start_method!r}; "
        "DROID-SLAM expects 'spawn'",
        RuntimeWarning,
    )
def est_calib(imagedir):
    """Roughly estimate pinhole intrinsics from the first image's dimensions.

    The focal length is guessed as the longer image side and the principal
    point is placed at the image centre.

    Args:
        imagedir: a directory of ``*.jpg`` frames (the first one in sorted
            order is used) or an explicit list of image paths (its first
            element is used).

    Returns:
        ``[fx, fy, cx, cy]`` with ``fx == fy``.

    Raises:
        FileNotFoundError: if no image is available or the first image
            cannot be read.
    """
    if isinstance(imagedir, list):
        imgfiles = imagedir
    else:
        imgfiles = sorted(glob(f'{imagedir}/*.jpg'))
    if not imgfiles:
        raise FileNotFoundError(f'No images found for {imagedir!r}')

    image = cv2.imread(imgfiles[0])
    if image is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError(f'Could not read image {imgfiles[0]!r}')

    h0, w0 = image.shape[:2]
    focal = np.max([h0, w0])
    cx, cy = w0 / 2., h0 / 2.
    return [focal, focal, cx, cy]
def get_dimention(imagedir):
    """Return the ``(H, W)`` frame size that ``image_stream`` produces.

    The first image is rescaled so its area is roughly 384 * 512 pixels
    (aspect ratio preserved), and each side is then cropped down to a
    multiple of 8 so it divides evenly into the ``H // 8, W // 8`` grid used
    for the confidence masks. The size is computed arithmetically; the
    result equals the shape of the resized-and-cropped image.

    Args:
        imagedir: a directory of ``*.jpg`` frames or a list of image paths;
            only the first image is inspected.

    Returns:
        ``(H, W)`` as Python ints.

    Raises:
        FileNotFoundError: if no image is available or the first image
            cannot be read.
    """
    if isinstance(imagedir, list):
        imgfiles = imagedir
    else:
        imgfiles = sorted(glob(f'{imagedir}/*.jpg'))
    if not imgfiles:
        raise FileNotFoundError(f'No images found for {imagedir!r}')

    image = cv2.imread(imgfiles[0])
    if image is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError(f'Could not read image {imgfiles[0]!r}')

    h0, w0 = image.shape[:2]
    scale = np.sqrt((384 * 512) / (h0 * w0))
    h1 = int(h0 * scale)
    w1 = int(w0 * scale)
    # Same crop as image_stream: drop rows/cols beyond the last multiple of 8.
    return h1 - h1 % 8, w1 - w1 % 8
def image_stream(imagedir, calib, stride, max_frame=None):
    """Yield resized frames and matching intrinsics for DROID-SLAM.

    Args:
        imagedir: a directory of ``*.jpg`` frames (sorted by name) or an
            explicit list of image paths.
        calib: ``[fx, fy, cx, cy]`` optionally followed by distortion
            coefficients; when more than four values are given, each frame
            is undistorted with ``cv2.undistort`` before resizing.
        stride: keep every ``stride``-th frame.
        max_frame: if given, stop after this many strided frames.

    Yields:
        ``(t, image, intrinsics)`` where ``t`` is the index within the
        strided sequence (not the original frame number), ``image`` is a
        ``(1, 3, H, W)`` tensor built from OpenCV's BGR image, and
        ``intrinsics`` is ``[fx, fy, cx, cy]`` rescaled to the resized frame.

    Raises:
        FileNotFoundError: if a frame cannot be read.
    """
    fx, fy, cx, cy = calib[:4]
    K = np.eye(3)
    K[0, 0], K[0, 2] = fx, cx
    K[1, 1], K[1, 2] = fy, cy

    if isinstance(imagedir, list):
        image_list = imagedir
    else:
        image_list = sorted(glob(f'{imagedir}/*.jpg'))
    image_list = image_list[::stride]
    if max_frame is not None:
        image_list = image_list[:max_frame]

    for t, imfile in enumerate(image_list):
        image = cv2.imread(imfile)
        if image is None:
            # cv2.imread reports failure by returning None rather than raising.
            raise FileNotFoundError(f'Could not read image {imfile!r}')
        if len(calib) > 4:
            image = cv2.undistort(image, K, calib[4:])

        # Rescale to ~384*512 pixels (aspect preserved), crop to multiples of 8.
        h0, w0 = image.shape[:2]
        scale = np.sqrt((384 * 512) / (h0 * w0))
        h1 = int(h0 * scale)
        w1 = int(w0 * scale)
        image = cv2.resize(image, (w1, h1))
        image = image[:h1 - h1 % 8, :w1 - w1 % 8]
        image = torch.as_tensor(image).permute(2, 0, 1)

        intrinsics = torch.as_tensor([fx, fy, cx, cy])
        if not intrinsics.is_floating_point():
            # An all-integer calib yields an int64 tensor, and the in-place
            # float scaling below would raise on it.
            intrinsics = intrinsics.float()
        intrinsics[0::2] *= (w1 / w0)  # fx, cx follow the width scale
        intrinsics[1::2] *= (h1 / h0)  # fy, cy follow the height scale
        yield t, image[None], intrinsics
def run_slam(imagedir, masks, calib=None, depth=None, stride=1,
             filter_thresh=2.4, disable_vis=True):
    """Run masked DROID-SLAM over an image sequence.

    Pixels whose resized mask value is >= 0.5 are zeroed before tracking,
    and the 1/8-resolution mask from ``preprocess_masks`` is passed to
    ``Droid.track`` as ``mask``.

    Args:
        imagedir: a directory of ``*.jpg`` frames or a list of image paths.
        masks: per-frame masks for the *unstrided* sequence; they are strided
            here with the same ``stride`` as the images. Presumably a
            ``(N, H, W)`` tensor — TODO confirm against callers.
        calib: ``[fx, fy, cx, cy, *distortion]``; estimated from the image
            size when None.
        depth: currently ignored — it is overwritten with None.
        stride: process every ``stride``-th frame.
        filter_thresh: written to the module-level ``args.filter_thresh``.
        disable_vis: written to the module-level ``args.disable_vis``.

    Returns:
        ``(droid, traj)``: the Droid instance and the result of
        ``droid.terminate``.

    Raises:
        IndexError: if there are fewer (strided) masks than frames.

    Note:
        Mutates the module-level ``args`` (filter_thresh, disable_vis,
        image_size).
    """
    droid = None
    # NOTE(review): the caller's depth is discarded, so tracking never gets
    # depth input; kept as-is to preserve behaviour — confirm intent.
    depth = None
    args.filter_thresh = filter_thresh
    args.disable_vis = disable_vis

    masks = masks[::stride]
    img_msks, conf_msks = preprocess_masks(imagedir, masks)
    if calib is None:
        calib = est_calib(imagedir)

    for (t, image, intrinsics) in tqdm(image_stream(imagedir, calib, stride)):
        if droid is None:
            # Size the tracker from the first (already resized) frame.
            args.image_size = [image.shape[2], image.shape[3]]
            droid = Droid(args)
        if t >= len(img_msks):
            raise IndexError(
                f'Only {len(img_msks)} masks for frame index {t} '
                f'(stride={stride}); masks must cover every frame')
        img_msk = img_msks[t]
        conf_msk = conf_msks[t]
        # Black out masked pixels before handing the frame to DROID.
        image = image * (img_msk < 0.5)
        droid.track(t, image, intrinsics=intrinsics, depth=depth, mask=conf_msk)

    # terminate() receives a fresh, unmasked pass over the same frames.
    traj = droid.terminate(image_stream(imagedir, calib, stride))
    return droid, traj
def run_droid_slam(imagedir, calib=None, depth=None, stride=1,
                   filter_thresh=2.4, disable_vis=True):
    """Run plain (unmasked) DROID-SLAM over an image sequence.

    Args:
        imagedir: a directory of ``*.jpg`` frames or a list of image paths.
        calib: ``[fx, fy, cx, cy, *distortion]``; estimated from the image
            size when None.
        depth: unused — tracking is always called with ``depth=None``.
            NOTE(review): confirm whether discarding it is intended.
        stride: process every ``stride``-th frame.
        filter_thresh: copied onto the module-level ``args.filter_thresh``.
        disable_vis: copied onto the module-level ``args.disable_vis``.

    Returns:
        ``(droid, traj)``: the tracker and the result of ``terminate``.

    Note:
        Mutates the module-level ``args`` (filter_thresh, disable_vis,
        image_size).
    """
    args.filter_thresh = filter_thresh
    args.disable_vis = disable_vis
    if calib is None:
        calib = est_calib(imagedir)

    tracker = None
    frames = image_stream(imagedir, calib, stride)
    for frame_idx, frame, frame_intrinsics in tqdm(frames):
        if tracker is None:
            # The first frame fixes the resolution the tracker is built for.
            args.image_size = list(frame.shape[2:4])
            tracker = Droid(args)
        tracker.track(frame_idx, frame, intrinsics=frame_intrinsics, depth=None)

    trajectory = tracker.terminate(image_stream(imagedir, calib, stride))
    return tracker, trajectory
def eval_slam(traj_est, cam_t, cam_q, return_traj=True, correct_scale=False, align=True, align_origin=False):
    """Compute the translational APE of an estimated trajectory.

    Both trajectories are timestamped by frame index (0, 1, 2, ...) and
    matched with evo's ``sync.associate_trajectories``, so only indices
    present in both are compared; the estimate and the reference may have
    different lengths.

    Args:
        traj_est: array of poses, one row per frame; columns 0-2 are used as
            positions and the remaining columns are handed to evo as a
            (w, x, y, z) quaternion. NOTE(review): if these rows come from
            DROID/lietorch they may be (qx, qy, qz, qw) — confirm. The APE
            value here only uses translation, but the returned trajectory's
            orientations would be affected.
        cam_t: reference positions, one row per frame.
        cam_q: reference (w, x, y, z) quaternions aligned with ``cam_t``.
        return_traj: also return the associated reference and estimate.
        correct_scale, align, align_origin: forwarded to evo's
            ``main_ape.ape``.

    Returns:
        ``stats`` (evo's ``result.stats``), or ``(stats, traj_ref,
        traj_est)`` when ``return_traj`` is true.
    """
    # Separate index timestamps per trajectory: PoseTrajectory3D requires one
    # timestamp per pose, and the association step handles length mismatch.
    est_stamps = np.arange(len(traj_est), dtype=np.float32)
    ref_stamps = np.arange(len(cam_t), dtype=np.float32)

    traj_est = PoseTrajectory3D(
        positions_xyz=traj_est[:, :3],
        orientations_quat_wxyz=traj_est[:, 3:],
        timestamps=est_stamps)
    traj_ref = PoseTrajectory3D(
        positions_xyz=cam_t.copy(),
        orientations_quat_wxyz=cam_q.copy(),
        timestamps=ref_stamps)

    traj_ref, traj_est = sync.associate_trajectories(traj_ref, traj_est)
    result = main_ape.ape(traj_ref, traj_est, est_name='traj',
                          pose_relation=PoseRelation.translation_part,
                          align=align, align_origin=align_origin,
                          correct_scale=correct_scale)
    stats = result.stats
    if return_traj:
        return stats, traj_ref, traj_est
    return stats
def test_slam(imagedir, img_msks, conf_msks, calib, stride=10, max_frame=50):
    """Run a short masked DROID-SLAM pass and return its reprojection error.

    A fresh default settings namespace is built from ``parser``, so the
    module-level ``args`` is left untouched. It uses a smaller frontend
    window and threshold and disables upsampling. At most ``max_frame``
    strided frames are tracked.

    Args:
        imagedir: a directory of ``*.jpg`` frames or a list of image paths.
        img_msks: full-resolution masks indexed by strided frame index;
            pixels with value >= 0.5 are zeroed.
        conf_msks: 1/8-resolution masks passed to ``Droid.track``.
        calib: ``[fx, fy, cx, cy, *distortion]``.
        stride: process every ``stride``-th frame.
        max_frame: maximum number of strided frames to track.

    Returns:
        Whatever ``droid.compute_error()`` reports.
    """
    slam_args = parser.parse_args([])
    slam_args.stereo = False
    slam_args.upsample = False
    slam_args.disable_vis = True
    slam_args.frontend_window = 10
    slam_args.frontend_thresh = 10

    droid = None
    frames = image_stream(imagedir, calib, stride, max_frame)
    for t, image, intrinsics in frames:
        if droid is None:
            slam_args.image_size = [image.shape[2], image.shape[3]]
            droid = Droid(slam_args)
        keep = img_msks[t] < 0.5
        droid.track(t, image * keep, intrinsics=intrinsics, mask=conf_msks[t])

    error = droid.compute_error()
    # Drop the tracker reference before returning.
    del droid
    return error
def search_focal_length(img_folder, masks, stride=10, max_frame=50,
                        low=500, high=1500, step=100):
    """Pick the focal length with the lowest short-run SLAM reprojection error.

    The size-based guess from ``est_calib`` is evaluated first. Then every
    candidate in ``range(low, high, step)`` is tried; ``high`` itself is not
    tried. A candidate replaces the current best only if its error is
    strictly lower, so ties keep the earlier value.

    Args:
        img_folder: a directory of ``*.jpg`` frames or a list of image paths.
        masks: per-frame masks for the unstrided sequence.
        stride: frame stride used for both frames and masks.
        max_frame: number of strided frames used per trial.
        low, high, step: the integer candidate range.

    Returns:
        The best focal length. This is ``calib[0]`` from the estimate (a
        numpy scalar) when no candidate beats it; otherwise it is the
        winning int.
    """
    subset = masks[::stride][:max_frame]
    img_msks, conf_msks = preprocess_masks(img_folder, subset)

    def _reprojection_error(candidate_calib):
        return test_slam(img_folder, img_msks, conf_msks,
                         stride=stride, calib=candidate_calib, max_frame=max_frame)

    calib = np.array(est_calib(img_folder))
    best_focal, best_err = calib[0], _reprojection_error(calib)

    for candidate in range(low, high, step):
        calib[:2] = candidate  # same focal for fx and fy
        err = _reprojection_error(calib)
        if err < best_err:
            best_focal, best_err = candidate, err

    print('Best focal length:', best_focal)
    return best_focal
def preprocess_masks(img_folder, masks, chunk_size=500):
    """Resize per-frame masks for masked DROID-SLAM.

    Two outputs are produced: one at the DROID frame size ``(H, W)``, used to
    black out image pixels, and one at ``(H // 8, W // 8)``, passed to
    ``Droid.track`` as ``mask``.

    Args:
        img_folder: a directory of ``*.jpg`` frames or a list of image paths;
            used only to obtain ``(H, W)`` via ``get_dimention``.
        masks: tensor of masks with frames along dim 0. torchvision's
            ``Resize`` resizes the last two dims. Presumably ``(N, H0, W0)``
            — TODO confirm against callers.
        chunk_size: number of frames resized per call, which bounds the size
            of each intermediate tensor.

    Returns:
        ``(img_msks, conf_msks)``, both concatenated along dim 0. For empty
        ``masks`` these are zero-length tensors with the target spatial
        sizes.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    if chunk_size < 1:
        raise ValueError(f'chunk_size must be positive, got {chunk_size}')

    H, W = get_dimention(img_folder)
    resize_img = Resize((H, W), antialias=True)
    resize_conf = Resize((H // 8, W // 8), antialias=True)

    if len(masks) == 0:
        # torch.cat([]) would raise; return empty stacks of the target size.
        lead = tuple(masks.shape[1:-2])
        return (masks.new_zeros((0, *lead, H, W)),
                masks.new_zeros((0, *lead, H // 8, W // 8)))

    img_chunks, conf_chunks = [], []
    for start in range(0, len(masks), chunk_size):
        chunk = masks[start:start + chunk_size]
        # Both resolutions are resized from the original masks.
        img_chunks.append(resize_img(chunk))
        conf_chunks.append(resize_conf(chunk))
    return torch.cat(img_chunks), torch.cat(conf_chunks)