"""
Export the torch hub Metric3D model to ONNX format.

Input normalization (ImageNet mean/std in 0-255 scale) is folded into the
exported model, so the ONNX graph consumes raw RGB images directly.
"""
import torch
class Metric3DExportModel(torch.nn.Module):
    """
    Wrapper for exporting a Metric3D meta-architecture to ONNX.

    Folds 0-255-scale ImageNet normalization into the graph so the exported
    model consumes raw RGB tensors. Add custom pre/post-processing here.
    """

    def __init__(self, meta_arch):
        """
        Args:
            meta_arch: model exposing `inference({"input": tensor})` and
                returning `(pred_depth, confidence, output_dict)`.
        """
        super().__init__()
        self.meta_arch = meta_arch
        # Fix: do NOT call .cuda() here. Buffers registered on a module follow
        # the module's device (`model.cuda()` / `.to(...)`), so hard-coding
        # CUDA was redundant for the export path and broke construction on
        # CPU-only machines.
        self.register_buffer(
            "rgb_mean", torch.tensor([123.675, 116.28, 103.53]).view(1, 3, 1, 1)
        )
        self.register_buffer(
            "rgb_std", torch.tensor([58.395, 57.12, 57.375]).view(1, 3, 1, 1)
        )
        # Expected (H, W) for ViT backbones; kept for callers that consult it.
        self.input_size = (616, 1064)

    def normalize_image(self, image):
        """Standardize a raw 0-255 RGB batch with ImageNet mean/std."""
        image = image - self.rgb_mean
        image = image / self.rgb_std
        return image

    def forward(self, image):
        """Normalize `image`, run inference, and return only the depth map."""
        image = self.normalize_image(image)
        with torch.no_grad():
            pred_depth, confidence, output_dict = self.meta_arch.inference(
                {"input": image}
            )
        return pred_depth
def update_vit_sampling(model):
    """
    Patch the ViT encoder so positional-embedding interpolation uses
    bilinear resampling: some TensorRT versions cannot run bicubic.
    Returns the same model instance with the method monkey-patched.
    """
    import math

    import torch.nn as nn

    def interpolate_pos_encoding_bilinear(self, x, w, h):
        orig_dtype = x.dtype
        num_patches = x.shape[1] - 1
        num_pos = self.pos_embed.shape[1] - 1
        # Fast path: grid already matches and the image is square.
        if num_patches == num_pos and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        cls_embed = pos_embed[:, 0]
        grid_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        grid_w = w // self.patch_size
        grid_h = h // self.patch_size
        # Small offset avoids floating point error in the interpolation;
        # see https://github.com/facebookresearch/dino/issues/8
        grid_w = grid_w + self.interpolate_offset
        grid_h = grid_h + self.interpolate_offset
        side = math.sqrt(num_pos)
        scale_w = float(grid_w) / side
        scale_h = float(grid_h) / side
        grid_embed = nn.functional.interpolate(
            grid_embed.reshape(1, int(side), int(side), dim).permute(0, 3, 1, 2),
            scale_factor=(scale_w, scale_h),
            mode="bilinear",  # Change from bicubic to bilinear
            antialias=self.interpolate_antialias,
        )
        assert int(grid_w) == grid_embed.shape[-2]
        assert int(grid_h) == grid_embed.shape[-1]
        grid_embed = grid_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        merged = torch.cat((cls_embed.unsqueeze(0), grid_embed), dim=1)
        return merged.to(orig_dtype)

    encoder = model.depth_model.encoder
    encoder.interpolate_pos_encoding = interpolate_pos_encoding_bilinear.__get__(
        encoder, encoder.__class__
    )
    return model
def main(model_name="metric3d_vit_small", modify_upsample=False):
    """
    Download a Metric3D checkpoint from torch hub and export it to ONNX.

    Args:
        model_name: torch-hub entry point, e.g. "metric3d_vit_small".
        modify_upsample: when True, patch the ViT positional-embedding
            interpolation to bilinear for TensorRT compatibility.
    """
    model = torch.hub.load("yvanyin/metric3d", model_name, pretrain=True)
    model.cuda().eval()
    if modify_upsample:
        model = update_vit_sampling(model)

    # ViT and ConvNeXt backbones expect different fixed input resolutions.
    batch = 1
    height, width = (616, 1064) if "vit" in model_name else (544, 1216)
    dummy_image = torch.randn([batch, 3, height, width]).cuda()

    export_model = Metric3DExportModel(model)
    export_model.eval()
    export_model.cuda()

    onnx_output = f"{model_name}.onnx"
    torch.onnx.export(
        export_model,
        (dummy_image,),
        onnx_output,
        input_names=["image"],
        output_names=["pred_depth"],
        opset_version=11,
    )
if __name__ == "__main__":
    # Expose `main`'s keyword arguments as a CLI via python-fire.
    import fire

    fire.Fire(main)