# standard library
from pathlib import Path
from typing import Optional, Union
import sys
# third party
import cv2
import numpy as np
import torch
from PIL import Image
from mmengine import Config
# metric 3d
metric3d_path = Path(__file__).resolve().parent
metric3d_mono_path = metric3d_path / 'mono'
sys.path.append(str(metric3d_path))
sys.path.append(str(metric3d_mono_path))
from mono.model.monodepth_model import get_configured_monodepth_model
from mono.utils.running import load_ckpt
from mono.utils.do_test import transform_test_data_scalecano, get_prediction
from mono.utils.mldb import load_data_info, reset_ckpt_path
from mono.utils.transform import gray_to_colormap
__all__ = ['Metric3D']
def calculate_radius(mask):
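    """Return the radius of the valid circular region of a square mask,
    i.e. the largest distance from the mask centre to a zero-valued pixel.
    """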
    # size N of the (square) mask
    N = mask.shape[0]
    # centre of the mask
    center = (N - 1) / 2
    # coordinates of all zero-valued (masked-out) pixels
    y, x = np.where(mask == 0)
    # distances from those pixels to the centre
    distances = np.sqrt((x - center) ** 2 + (y - center) ** 2)
    # the radius is the maximum of these distances
    radius = distances.max()
    return radius
class Metric3D:
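    """Thin inference wrapper around the Metric3D v2 monocular metric-depth models."""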
cfg_: Config
model_: torch.nn.Module
def __init__(
self,
checkpoint: Union[str, Path] = './weights/metric_depth_vit_large_800k.pth',
model_name: str = 'v2-L',
) -> None:
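        """Build the monodepth model, load the checkpoint, and set eval mode.

        checkpoint: path to a Metric3D checkpoint file.
        model_name: one of 'v2-S', 'v2-L' or 'v2-g' (small / large / giant2 ViT backbones).
        """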
checkpoint = Path(checkpoint).resolve()
        cfg: Config = self._load_config_(model_name, checkpoint)
# build model
        model = get_configured_monodepth_model(cfg)
model = torch.nn.DataParallel(model).cuda()
model, _, _, _ = load_ckpt(cfg.load_from, model, strict_match=False)
model.eval()
# save to self
self.cfg_ = cfg
self.model_ = model
@torch.no_grad()
def __call__(
self,
rgb_image: Union[np.ndarray, Image.Image, str, Path],
intrinsic: Union[str, Path, np.ndarray],
        d_max: float = 300,
        d_min: float = 0,
        margin_mask: Optional[np.ndarray] = None,
        crop_margin: int = 0,
) -> np.ndarray:
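        """Predict a metric depth map for a single RGB image.

        rgb_image: an image array, a PIL image, or a path to an image file.
        intrinsic: [fx, fy, cx, cy] in pixels (longer vectors are truncated),
            or a path to a text file containing these values.
        d_max, d_min: predictions outside this range are zeroed out.
        margin_mask, crop_margin: if crop_margin != 0, a centred square of
            radius (H/2 - crop_margin) is cropped before inference and the
            prediction is pasted back into a map of margin_mask's shape.
        Returns an (H, W) float array of metric depth.
        """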
# read image
if isinstance(rgb_image, (str, Path)):
rgb_image = np.array(Image.open(rgb_image))
elif isinstance(rgb_image, Image.Image):
rgb_image = np.array(rgb_image)
        if isinstance(intrinsic, (str, Path)):
            intrinsic = np.loadtxt(intrinsic)
        intrinsic = intrinsic[:4]  # keep [fx, fy, cx, cy]
        # optionally crop a centred square region; margin_mask (the original-size
        # mask) is required here, and the principal point is moved to the crop centre
        if crop_margin != 0:
            original_h, original_w = margin_mask.shape
            # radius = calculate_radius(margin_mask) - crop_margin
            radius = original_h // 2 - crop_margin
            left, right = int(original_w // 2 - radius), int(original_w // 2 + radius)
            up, bottom = int(original_h // 2 - radius), int(original_h // 2 + radius)
            rgb_image = rgb_image[up:bottom, left:right]
            h, w = rgb_image.shape[:2]
            intrinsic[2] = w / 2
            intrinsic[3] = h / 2
        # compute the resize scale to the fixed network input resolution; it is
        # reused below to convert canonical depth back to metric depth
        h, w = rgb_image.shape[:2]
        input_size = (616, 1064)
        scale = min(input_size[0] / h, input_size[1] / w)
# transform image
rgb_input, cam_models_stacks, pad, label_scale_factor = \
transform_test_data_scalecano(rgb_image, intrinsic, self.cfg_.data_basic)
# predict depth
normalize_scale = self.cfg_.data_basic.depth_range[1]
rgb_input = rgb_input.unsqueeze(0)
        pred_depth, output = get_prediction(
            model=self.model_,
            input=rgb_input,
            cam_model=cam_models_stacks,
            pad_info=pad,
            scale_info=label_scale_factor,
            gt_depth=None,
            normalize_scale=normalize_scale,
            ori_shape=[h, w],
        )
        # post process: clamp predictions to the valid depth range
        pred_depth = pred_depth.squeeze().cpu().numpy()
        pred_depth[pred_depth > d_max] = 0
        pred_depth[pred_depth < d_min] = 0
        # undo the padding added by the test-time transform
        pred_depth = pred_depth[pad[0] : pred_depth.shape[0] - pad[1], pad[2] : pred_depth.shape[1] - pad[3]]
        # rescale from the canonical camera (focal length 1000.0) to the real camera
        canonical_to_real_scale = intrinsic[0] * scale / 1000.0
        pred_depth = pred_depth * canonical_to_real_scale  # the depth is now metric
        # if the input was cropped at the beginning, paste the prediction back
        # into a full-size depth map
        if crop_margin != 0:
            final_depth = np.zeros((original_h, original_w))
            pred_depth = cv2.resize(pred_depth, (w, h))
            final_depth[up:bottom, left:right] = pred_depth
            return final_depth
        return pred_depth
def _load_config_(
self,
model_name: str,
checkpoint: Union[str, Path],
) -> Config:
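        """Load the mmengine Config matching model_name and point it at the checkpoint."""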
print(f'Loading model {model_name} from {checkpoint}')
config_path = metric3d_path / 'mono/configs/HourglassDecoder'
assert model_name in ['v2-L', 'v2-S', 'v2-g'], f"Model {model_name} not supported"
# load config file
        config_file = {
            'v2-S': 'vit.raft5.small.py',
            'v2-L': 'vit.raft5.large.py',
            'v2-g': 'vit.raft5.giant2.py',
        }[model_name]
        cfg = Config.fromfile(str(config_path / config_file))
cfg.load_from = str(checkpoint)
# load data info
data_info = {}
load_data_info('data_info', data_info=data_info)
cfg.mldb_info = data_info
# update check point info
reset_ckpt_path(cfg.model, data_info)
# set distributed
cfg.distributed = False
return cfg
@staticmethod
def gray_to_colormap(depth: np.ndarray) -> np.ndarray:
return gray_to_colormap(depth)
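
# A minimal usage sketch, assuming the default weights have been downloaded;
# the image path and intrinsics below are placeholders, not values shipped
# with this module.
if __name__ == '__main__':
    metric3d = Metric3D(
        checkpoint='./weights/metric_depth_vit_large_800k.pth',
        model_name='v2-L',
    )
    # hypothetical [fx, fy, cx, cy] in pixels for the capturing camera
    K = np.array([1000.0, 1000.0, 640.0, 360.0])
    depth = metric3d('example.png', K)  # (H, W) float array of metric depth
    vis = Metric3D.gray_to_colormap(depth)  # assumed to return an RGB image
    cv2.imwrite('depth_vis.png', vis[:, :, ::-1])  # RGB -> BGR for OpenCV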