"""
Export the torch hub model to ONNX format. Normalization is done in the model.
"""

import torch


class Metric3DExportModel(torch.nn.Module):
    """
    Wrapper module for ONNX export; custom preprocessing and postprocessing
    belong here. Normalization is baked into the graph, so the exported model
    expects a float32 BCHW RGB tensor with raw 0-255 values.
    """

    def __init__(self, meta_arch):
        super().__init__()
        self.meta_arch = meta_arch
        # ImageNet mean/std on the 0-255 scale, stored as buffers so they are
        # serialized into the exported graph.
        self.register_buffer(
            "rgb_mean", torch.tensor([123.675, 116.28, 103.53]).view(1, 3, 1, 1).cuda()
        )
        self.register_buffer(
            "rgb_std", torch.tensor([58.395, 57.12, 57.375]).view(1, 3, 1, 1).cuda()
        )
        # (H, W) expected by the ViT variants; kept for reference (unused in forward).
        self.input_size = (616, 1064)

    def normalize_image(self, image):
        image = image - self.rgb_mean
        image = image / self.rgb_std
        return image

    def forward(self, image):
        image = self.normalize_image(image)
        with torch.no_grad():
            pred_depth, confidence, output_dict = self.meta_arch.inference(
                {"input": image}
            )
        return pred_depth


def update_vit_sampling(model):
    """
    For ViT models running on some TensorRT versions, the interpolation of the
    positional encoding must be changed from bicubic to bilinear.
    """
    import math

    import torch.nn as nn

    def interpolate_pos_encoding_bilinear(self, x, w, h):
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset

        sqrt_N = math.sqrt(N)
        sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(
                0, 3, 1, 2
            ),
            scale_factor=(sx, sy),
            mode="bilinear",  # Change from bicubic to bilinear
            antialias=self.interpolate_antialias,
        )

        assert int(w0) == patch_pos_embed.shape[-2]
        assert int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
            previous_dtype
        )

    # Bind the replacement as an instance method on the encoder, shadowing
    # the class's original bicubic implementation.
    model.depth_model.encoder.interpolate_pos_encoding = (
        interpolate_pos_encoding_bilinear.__get__(
            model.depth_model.encoder, model.depth_model.encoder.__class__
        )
    )
    return model
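
# The exported ONNX file is typically consumed by TensorRT afterwards, e.g.
# via trtexec (illustrative command; available flags depend on the installed
# TensorRT version):
#   trtexec --onnx=metric3d_vit_small.onnx --saveEngine=metric3d_vit_small.engine --fp16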


def main(model_name="metric3d_vit_small", modify_upsample=False):
    model = torch.hub.load("yvanyin/metric3d", model_name, pretrain=True)
    model.cuda().eval()

    if modify_upsample:
        model = update_vit_sampling(model)

    B = 1
    # The ViT variants take 616x1064 inputs; the other variants take 544x1216.
    if "vit" in model_name:
        dummy_image = torch.randn([B, 3, 616, 1064]).cuda()
    else:
        dummy_image = torch.randn([B, 3, 544, 1216]).cuda()

    export_model = Metric3DExportModel(model)
    export_model.eval()
    export_model.cuda()

    onnx_output = f"{model_name}.onnx"
    dummy_input = (dummy_image,)
    torch.onnx.export(
        export_model,
        dummy_input,
        onnx_output,
        input_names=["image"],
        output_names=["pred_depth"],
        opset_version=11,
    )
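

def verify_onnx(onnx_path, height=616, width=1064):
    """
    Minimal sanity check of an exported graph: a sketch, not part of the
    original export flow. Assumes the `onnxruntime` package is installed; the
    input/output names match those passed to torch.onnx.export above, and
    height/width must match the dummy input used at export time.
    """
    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    # Random RGB values on the raw 0-255 scale the model expects.
    dummy = (np.random.rand(1, 3, height, width) * 255.0).astype(np.float32)
    (pred_depth,) = sess.run(["pred_depth"], {"image": dummy})
    print("pred_depth shape:", pred_depth.shape)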


if __name__ == "__main__":
    from fire import Fire

    Fire(main)