"""
Export the torch hub model to ONNX format. Normalization is done in the model.
"""
import torch


class Metric3DExportModel(torch.nn.Module):
    """
    The model for exporting to ONNX format. Add custom preprocessing and postprocessing here.
    """

    def __init__(self, meta_arch):
        super().__init__()
        self.meta_arch = meta_arch
        # ImageNet mean/std in the 0-255 pixel range. Registered as buffers so they
        # move with the module (e.g. on .cuda()) and are baked into the ONNX graph;
        # the eager .cuda() calls from the original are dropped since main() moves
        # the whole module to the GPU anyway.
        self.register_buffer(
            "rgb_mean", torch.tensor([123.675, 116.28, 103.53]).view(1, 3, 1, 1)
        )
        self.register_buffer(
            "rgb_std", torch.tensor([58.395, 57.12, 57.375]).view(1, 3, 1, 1)
        )
        self.input_size = (616, 1064)

    def normalize_image(self, image):
        # Expects raw RGB pixel values in [0, 255].
        image = image - self.rgb_mean
        image = image / self.rgb_std
        return image

    def forward(self, image):
        image = self.normalize_image(image)
        with torch.no_grad():
            pred_depth, confidence, output_dict = self.meta_arch.inference(
                {"input": image}
            )
        return pred_depth
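
# Minimal usage sketch for the wrapper (illustrative only, not part of the
# original script); the wrapper takes raw RGB pixel values in [0, 255] and
# normalizes them internally:
#
#   meta_arch = torch.hub.load("yvanyin/metric3d", "metric3d_vit_small", pretrain=True)
#   export_model = Metric3DExportModel(meta_arch.cuda().eval())
#   raw_rgb = torch.rand(1, 3, 616, 1064).cuda() * 255.0
#   pred_depth = export_model(raw_rgb)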


def update_vit_sampling(model):
    """
    For ViT models running on some TensorRT versions, the positional-embedding
    interpolation must be changed from bicubic to bilinear.
    """
    import math

    import torch.nn as nn

    def interpolate_pos_encoding_bilinear(self, x, w, h):
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        # Add a small offset to avoid floating-point error in the interpolation;
        # see the discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
        sqrt_N = math.sqrt(N)
        sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(
                0, 3, 1, 2
            ),
            scale_factor=(sx, sy),
            mode="bilinear",  # changed from bicubic to bilinear
            antialias=self.interpolate_antialias,
        )
        assert int(w0) == patch_pos_embed.shape[-2]
        assert int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
            previous_dtype
        )

    # Bind the replacement function to the encoder instance via __get__, overriding
    # the encoder's original bicubic interpolate_pos_encoding method.
    model.depth_model.encoder.interpolate_pos_encoding = (
        interpolate_pos_encoding_bilinear.__get__(
            model.depth_model.encoder, model.depth_model.encoder.__class__
        )
    )
    return model
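
# Note (an observation, not from the original source): bilinear resampling of the
# positional embeddings is numerically close to, but not identical to, bicubic,
# so depth predictions from a patched model may differ slightly from the
# unpatched PyTorch checkpoint.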


def main(model_name="metric3d_vit_small", modify_upsample=False):
    model = torch.hub.load("yvanyin/metric3d", model_name, pretrain=True)
    model.cuda().eval()
    if modify_upsample:
        model = update_vit_sampling(model)

    B = 1
    # ViT variants take 616x1064 inputs; the other hub variants take 544x1216.
    if "vit" in model_name:
        dummy_image = torch.randn([B, 3, 616, 1064]).cuda()
    else:
        dummy_image = torch.randn([B, 3, 544, 1216]).cuda()

    export_model = Metric3DExportModel(model)
    export_model.eval()
    export_model.cuda()

    onnx_output = f"{model_name}.onnx"
    dummy_input = (dummy_image,)
    torch.onnx.export(
        export_model,
        dummy_input,
        onnx_output,
        input_names=["image"],
        output_names=["pred_depth"],
        opset_version=11,
    )
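
    # Optional sanity check (not in the original script; assumes `onnxruntime` or
    # `onnxruntime-gpu` is installed): run the exported graph once and report the
    # output shape.
    try:
        import onnxruntime as ort

        sess = ort.InferenceSession(
            onnx_output, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
        )
        (onnx_depth,) = sess.run(["pred_depth"], {"image": dummy_image.cpu().numpy()})
        print(f"Exported {onnx_output}; ONNX pred_depth shape: {onnx_depth.shape}")
    except ImportError:
        print(f"Exported {onnx_output}; install onnxruntime to verify the graph.")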


if __name__ == "__main__":
    from fire import Fire

    Fire(main)