# standard library
from pathlib import Path
from typing import Optional, Union
import sys
# third party
import cv2
import numpy as np
import torch
from PIL import Image
from mmengine import Config
# metric 3d
metric3d_path = Path(__file__).resolve().parent
metric3d_mono_path = metric3d_path / 'mono'
sys.path.append(str(metric3d_path))
sys.path.append(str(metric3d_mono_path))
from mono.model.monodepth_model import get_configured_monodepth_model
from mono.utils.running import load_ckpt
from mono.utils.do_test import transform_test_data_scalecano, get_prediction
from mono.utils.mldb import load_data_info, reset_ckpt_path
from mono.utils.transform import gray_to_colormap

__all__ = ['Metric3D']

def calculate_radius(mask):
    # size N of the (square) mask
    N = mask.shape[0]

    # centre point of the mask
    center = (N - 1) / 2

    # coordinates of all pixels whose value is 0
    y, x = np.where(mask == 0)

    # distance from each of those pixels to the centre
    distances = np.sqrt((x - center) ** 2 + (y - center) ** 2)

    # the radius is the maximum of these distances
    radius = distances.max()

    return radius

class Metric3D:
    cfg_: Config
    model_: torch.nn.Module

    def __init__(
        self,
        checkpoint: Union[str, Path] = './weights/metric_depth_vit_large_800k.pth',
        model_name: str = 'v2-L',
    ) -> None:
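        """Load a Metric3D monodepth model from a checkpoint.

        The model is wrapped in torch.nn.DataParallel and moved to CUDA,
        so a GPU is required. Supported model names are 'v2-L', 'v2-S'
        and 'v2-g'.
        """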
        checkpoint = Path(checkpoint).resolve()
        cfg: Config = self._load_config_(model_name, checkpoint)
        # build the model and load the checkpoint weights
        model = get_configured_monodepth_model(cfg)
        model = torch.nn.DataParallel(model).cuda()
        model, _, _, _ = load_ckpt(cfg.load_from, model, strict_match=False)
        model.eval()
        # save to self
        self.cfg_ = cfg
        self.model_ = model

    @torch.no_grad()
    def __call__(
        self,
        rgb_image: Union[np.ndarray, Image.Image, str, Path],
        intrinsic: Union[str, Path, np.ndarray],
        d_max: float = 300,
        d_min: float = 0,
        margin_mask: Optional[np.ndarray] = None,
        crop_margin: int = 0,
    ) -> np.ndarray:
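        """Predict a metric depth map for a single RGB image.

        `intrinsic` is, or loads to, an array whose first four values are
        [fx, fy, cx, cy]; only those four are used. Predicted depths
        outside [d_min, d_max] are zeroed out. When `crop_margin` is
        non-zero, a centred square of radius (height // 2 - crop_margin)
        is cropped out before inference (intended for circular margin
        masks) and the prediction is pasted back into a full-size map.
        """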
        # read image
        if isinstance(rgb_image, (str, Path)):
            rgb_image = np.array(Image.open(rgb_image))
        elif isinstance(rgb_image, Image.Image):
            rgb_image = np.array(rgb_image)

        # read intrinsics; only [fx, fy, cx, cy] are used downstream
        if isinstance(intrinsic, (str, Path)):
            intrinsic = np.loadtxt(intrinsic)
        intrinsic = intrinsic[:4]

        # crop a centred square inside the circular margin mask
        if crop_margin != 0:
            original_h, original_w = margin_mask.shape
            # radius = calculate_radius(margin_mask) - crop_margin
            radius = original_h // 2 - crop_margin
            left, right = int(original_w // 2 - radius), int(original_w // 2 + radius)
            up, bottom = int(original_h // 2 - radius), int(original_h // 2 + radius)
            rgb_image = rgb_image[up:bottom, left:right]
            h, w = rgb_image.shape[:2]
            # re-centre the principal point after the crop
            intrinsic[2] = w / 2
            intrinsic[3] = h / 2
            # cv2.imwrite("debug.png", rgb_image[:, :, ::-1])  # debug dump (RGB -> BGR)
        h, w = rgb_image.shape[:2]

        # scale factor from the original image to the canonical network input size
        input_size = (616, 1064)
        scale = min(input_size[0] / h, input_size[1] / w)
        # transform image
        rgb_input, cam_models_stacks, pad, label_scale_factor = \
            transform_test_data_scalecano(rgb_image, intrinsic, self.cfg_.data_basic)
        # predict depth
        normalize_scale = self.cfg_.data_basic.depth_range[1]
        rgb_input = rgb_input.unsqueeze(0)
        pred_depth, output = get_prediction(
            model=self.model_,
            input=rgb_input,
            cam_model=cam_models_stacks,
            pad_info=pad,
            scale_info=label_scale_factor,
            gt_depth=None,
            normalize_scale=normalize_scale,
            ori_shape=[h, w],
        )
        
        # post process: zero out depths outside [d_min, d_max]
        # pred_depth = (pred_depth > 0) * (pred_depth < 300) * pred_depth

        pred_depth = pred_depth.squeeze().cpu().numpy()
        pred_depth[pred_depth > d_max] = 0
        pred_depth[pred_depth < d_min] = 0

        # remove the padding added by the test-time transform
        pred_depth = pred_depth[pad[0] : pred_depth.shape[0] - pad[1], pad[2] : pred_depth.shape[1] - pad[3]]

        # convert from canonical-camera depth to metric depth;
        # 1000.0 is the focal length of the canonical camera
        canonical_to_real_scale = intrinsic[0] * scale / 1000.0
        pred_depth = pred_depth * canonical_to_real_scale  # the depth is now metric

        # if the image was cropped at the beginning, paste the prediction
        # back into a full-size depth map
        if margin_mask is not None and crop_margin != 0:
            final_depth = np.zeros((original_h, original_w))
            pred_depth = cv2.resize(pred_depth, (w, h))
            final_depth[up:bottom, left:right] = pred_depth
            return final_depth
        return pred_depth

    def _load_config_(
        self,
        model_name: str,
        checkpoint: Union[str, Path],
    ) -> Config:
        print(f'Loading model {model_name} from {checkpoint}')
        config_path = metric3d_path / 'mono/configs/HourglassDecoder'
        config_files = {
            'v2-L': 'vit.raft5.large.py',
            'v2-S': 'vit.raft5.small.py',
            'v2-g': 'vit.raft5.giant2.py',
        }
        assert model_name in config_files, f"Model {model_name} not supported"
        # load config file
        cfg = Config.fromfile(str(config_path / config_files[model_name]))
        cfg.load_from = str(checkpoint)
        # load data info
        data_info = {}
        load_data_info('data_info', data_info=data_info)
        cfg.mldb_info = data_info
        # update checkpoint info
        reset_ckpt_path(cfg.model, data_info)
        # set distributed
        cfg.distributed = False
        
        return cfg
    
    @staticmethod
    def gray_to_colormap(depth: np.ndarray) -> np.ndarray:
        return gray_to_colormap(depth)
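

if __name__ == '__main__':
    # Minimal usage sketch; the image and output file names below are
    # assumed placeholders, not part of the module. Intrinsics can be a
    # text file path or an array [fx, fy, cx, cy].
    model = Metric3D(checkpoint='./weights/metric_depth_vit_large_800k.pth')
    depth = model(
        'example.jpg',  # assumed input image
        intrinsic=np.array([1000.0, 1000.0, 320.0, 240.0]),
    )
    # gray_to_colormap is assumed to return a uint8 RGB image
    Image.fromarray(Metric3D.gray_to_colormap(depth)).save('depth_vis.png')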