Spaces:
Running
on
Zero
Running
on
Zero
first version
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +710 -4
- example_building/0.jpg +0 -0
- example_building/1.jpg +0 -0
- example_building/2.jpg +0 -0
- example_building/3.jpg +0 -0
- example_building/4.jpg +0 -0
- example_building/5.jpg +0 -0
- example_building/6.jpg +0 -0
- example_building/7.jpg +0 -0
- example_building/8.jpg +0 -0
- requirements.txt +19 -0
- vggt/heads/__pycache__/camera_head.cpython-310.pyc +0 -0
- vggt/heads/__pycache__/camera_head.cpython-311.pyc +0 -0
- vggt/heads/__pycache__/camera_head.cpython-312.pyc +0 -0
- vggt/heads/__pycache__/dpt_head.cpython-310.pyc +0 -0
- vggt/heads/__pycache__/dpt_head.cpython-311.pyc +0 -0
- vggt/heads/__pycache__/dpt_head.cpython-312.pyc +0 -0
- vggt/heads/__pycache__/head_act.cpython-310.pyc +0 -0
- vggt/heads/__pycache__/head_act.cpython-311.pyc +0 -0
- vggt/heads/__pycache__/head_act.cpython-312.pyc +0 -0
- vggt/heads/__pycache__/track_head.cpython-310.pyc +0 -0
- vggt/heads/__pycache__/track_head.cpython-311.pyc +0 -0
- vggt/heads/__pycache__/track_head.cpython-312.pyc +0 -0
- vggt/heads/__pycache__/utils.cpython-310.pyc +0 -0
- vggt/heads/__pycache__/utils.cpython-311.pyc +0 -0
- vggt/heads/__pycache__/utils.cpython-312.pyc +0 -0
- vggt/heads/camera_head.py +162 -0
- vggt/heads/dpt_head.py +497 -0
- vggt/heads/head_act.py +125 -0
- vggt/heads/track_head.py +108 -0
- vggt/heads/track_modules/__init__.py +5 -0
- vggt/heads/track_modules/__pycache__/__init__.cpython-310.pyc +0 -0
- vggt/heads/track_modules/__pycache__/__init__.cpython-311.pyc +0 -0
- vggt/heads/track_modules/__pycache__/__init__.cpython-312.pyc +0 -0
- vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-310.pyc +0 -0
- vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-311.pyc +0 -0
- vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-312.pyc +0 -0
- vggt/heads/track_modules/__pycache__/blocks.cpython-310.pyc +0 -0
- vggt/heads/track_modules/__pycache__/blocks.cpython-311.pyc +0 -0
- vggt/heads/track_modules/__pycache__/blocks.cpython-312.pyc +0 -0
- vggt/heads/track_modules/__pycache__/modules.cpython-310.pyc +0 -0
- vggt/heads/track_modules/__pycache__/modules.cpython-311.pyc +0 -0
- vggt/heads/track_modules/__pycache__/modules.cpython-312.pyc +0 -0
- vggt/heads/track_modules/__pycache__/utils.cpython-310.pyc +0 -0
- vggt/heads/track_modules/__pycache__/utils.cpython-311.pyc +0 -0
- vggt/heads/track_modules/__pycache__/utils.cpython-312.pyc +0 -0
- vggt/heads/track_modules/base_track_predictor.py +209 -0
- vggt/heads/track_modules/blocks.py +246 -0
- vggt/heads/track_modules/modules.py +218 -0
- vggt/heads/track_modules/utils.py +226 -0
app.py
CHANGED
@@ -1,7 +1,713 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
-
demo
|
7 |
-
demo.launch()
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import os
|
8 |
+
import cv2
|
9 |
+
import torch
|
10 |
+
import numpy as np
|
11 |
import gradio as gr
|
12 |
+
import sys
|
13 |
+
import shutil
|
14 |
+
from datetime import datetime
|
15 |
+
import glob
|
16 |
+
import gc
|
17 |
+
import time
|
18 |
+
|
19 |
+
from visual_util import predictions_to_glb
|
20 |
+
from vggt.models.vggt import VGGT
|
21 |
+
from vggt.utils.load_fn import load_and_preprocess_images
|
22 |
+
from vggt.utils.pose_enc import pose_encoding_to_extri_intri
|
23 |
+
from vggt.utils.geometry import unproject_depth_map_to_point_map
|
24 |
+
|
25 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
26 |
+
|
27 |
+
print("Initializing and loading VGGT model...")
|
28 |
+
# model = VGGT.from_pretrained("facebook/VGGT-1B") # another way to load the model
|
29 |
+
|
30 |
+
model_path = "https://huggingface.co/lch01/StreamVGGT/blob/main/checkpoints.pth"
|
31 |
+
model = VGGT(use_causal_global=True, use_distil=True)
|
32 |
+
ckpt = torch.load(torch.hub.load_state_dict_from_url(model_path), map_location=device)
|
33 |
+
model.load_state_dict(ckpt, strict=True)
|
34 |
+
model = model.to(device)
|
35 |
+
model.eval()
|
36 |
+
del ckpt
|
37 |
+
|
38 |
+
|
39 |
+
# -------------------------------------------------------------------------
|
40 |
+
# 1) Core model inference
|
41 |
+
# -------------------------------------------------------------------------
|
42 |
+
def run_model(target_dir, model) -> dict:
|
43 |
+
"""
|
44 |
+
Run the VGGT model on images in the 'target_dir/images' folder and return predictions.
|
45 |
+
"""
|
46 |
+
print(f"Processing images from {target_dir}")
|
47 |
+
|
48 |
+
# Device check
|
49 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
50 |
+
if not torch.cuda.is_available():
|
51 |
+
raise ValueError("CUDA is not available. Check your environment.")
|
52 |
+
|
53 |
+
# Move model to device
|
54 |
+
model = model.to(device)
|
55 |
+
model.eval()
|
56 |
+
|
57 |
+
# Load and preprocess images
|
58 |
+
image_names = glob.glob(os.path.join(target_dir, "images", "*"))
|
59 |
+
image_names = sorted(image_names)
|
60 |
+
print(f"Found {len(image_names)} images")
|
61 |
+
if len(image_names) == 0:
|
62 |
+
raise ValueError("No images found. Check your upload.")
|
63 |
+
|
64 |
+
images = load_and_preprocess_images(image_names).to(device)
|
65 |
+
print(f"Preprocessed images shape: {images.shape}")
|
66 |
+
|
67 |
+
# Run inference
|
68 |
+
print("Running inference...")
|
69 |
+
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
|
70 |
+
|
71 |
+
with torch.no_grad():
|
72 |
+
with torch.cuda.amp.autocast(dtype=dtype):
|
73 |
+
predictions = model(images)
|
74 |
+
|
75 |
+
# Convert pose encoding to extrinsic and intrinsic matrices
|
76 |
+
print("Converting pose encoding to extrinsic and intrinsic matrices...")
|
77 |
+
extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])
|
78 |
+
predictions["extrinsic"] = extrinsic
|
79 |
+
predictions["intrinsic"] = intrinsic
|
80 |
+
|
81 |
+
# Convert tensors to numpy
|
82 |
+
for key in predictions.keys():
|
83 |
+
if isinstance(predictions[key], torch.Tensor):
|
84 |
+
predictions[key] = predictions[key].cpu().numpy().squeeze(0) # remove batch dimension
|
85 |
+
predictions['pose_enc_list'] = None # remove pose_enc_list
|
86 |
+
|
87 |
+
# Generate world points from depth map
|
88 |
+
print("Computing world points from depth map...")
|
89 |
+
depth_map = predictions["depth"] # (S, H, W, 1)
|
90 |
+
world_points = unproject_depth_map_to_point_map(depth_map, predictions["extrinsic"], predictions["intrinsic"])
|
91 |
+
predictions["world_points_from_depth"] = world_points
|
92 |
+
|
93 |
+
# Clean up
|
94 |
+
torch.cuda.empty_cache()
|
95 |
+
return predictions
|
96 |
+
|
97 |
+
|
98 |
+
# -------------------------------------------------------------------------
|
99 |
+
# 2) Handle uploaded video/images --> produce target_dir + images
|
100 |
+
# -------------------------------------------------------------------------
|
101 |
+
def handle_uploads(input_video, input_images):
|
102 |
+
"""
|
103 |
+
Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
|
104 |
+
images or extracted frames from video into it. Return (target_dir, image_paths).
|
105 |
+
"""
|
106 |
+
start_time = time.time()
|
107 |
+
gc.collect()
|
108 |
+
torch.cuda.empty_cache()
|
109 |
+
|
110 |
+
# Create a unique folder name
|
111 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
112 |
+
target_dir = f"input_images_{timestamp}"
|
113 |
+
target_dir_images = os.path.join(target_dir, "images")
|
114 |
+
|
115 |
+
# Clean up if somehow that folder already exists
|
116 |
+
if os.path.exists(target_dir):
|
117 |
+
shutil.rmtree(target_dir)
|
118 |
+
os.makedirs(target_dir)
|
119 |
+
os.makedirs(target_dir_images)
|
120 |
+
|
121 |
+
image_paths = []
|
122 |
+
|
123 |
+
# --- Handle images ---
|
124 |
+
if input_images is not None:
|
125 |
+
for file_data in input_images:
|
126 |
+
if isinstance(file_data, dict) and "name" in file_data:
|
127 |
+
file_path = file_data["name"]
|
128 |
+
else:
|
129 |
+
file_path = file_data
|
130 |
+
dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
|
131 |
+
shutil.copy(file_path, dst_path)
|
132 |
+
image_paths.append(dst_path)
|
133 |
+
|
134 |
+
# --- Handle video ---
|
135 |
+
if input_video is not None:
|
136 |
+
if isinstance(input_video, dict) and "name" in input_video:
|
137 |
+
video_path = input_video["name"]
|
138 |
+
else:
|
139 |
+
video_path = input_video
|
140 |
+
|
141 |
+
vs = cv2.VideoCapture(video_path)
|
142 |
+
fps = vs.get(cv2.CAP_PROP_FPS)
|
143 |
+
frame_interval = int(fps * 1) # 1 frame/sec
|
144 |
+
|
145 |
+
count = 0
|
146 |
+
video_frame_num = 0
|
147 |
+
while True:
|
148 |
+
gotit, frame = vs.read()
|
149 |
+
if not gotit:
|
150 |
+
break
|
151 |
+
count += 1
|
152 |
+
if count % frame_interval == 0:
|
153 |
+
image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.png")
|
154 |
+
cv2.imwrite(image_path, frame)
|
155 |
+
image_paths.append(image_path)
|
156 |
+
video_frame_num += 1
|
157 |
+
|
158 |
+
# Sort final images for gallery
|
159 |
+
image_paths = sorted(image_paths)
|
160 |
+
|
161 |
+
end_time = time.time()
|
162 |
+
print(f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds")
|
163 |
+
return target_dir, image_paths
|
164 |
+
|
165 |
+
|
166 |
+
# -------------------------------------------------------------------------
|
167 |
+
# 3) Update gallery on upload
|
168 |
+
# -------------------------------------------------------------------------
|
169 |
+
def update_gallery_on_upload(input_video, input_images):
|
170 |
+
"""
|
171 |
+
Whenever user uploads or changes files, immediately handle them
|
172 |
+
and show in the gallery. Return (target_dir, image_paths).
|
173 |
+
If nothing is uploaded, returns "None" and empty list.
|
174 |
+
"""
|
175 |
+
if not input_video and not input_images:
|
176 |
+
return None, None, None, None
|
177 |
+
target_dir, image_paths = handle_uploads(input_video, input_images)
|
178 |
+
return None, target_dir, image_paths, "Upload complete. Click 'Reconstruct' to begin 3D processing."
|
179 |
+
|
180 |
+
|
181 |
+
# -------------------------------------------------------------------------
|
182 |
+
# 4) Reconstruction: uses the target_dir plus any viz parameters
|
183 |
+
# -------------------------------------------------------------------------
|
184 |
+
def gradio_demo(
|
185 |
+
target_dir,
|
186 |
+
conf_thres=3.0,
|
187 |
+
frame_filter="All",
|
188 |
+
mask_black_bg=False,
|
189 |
+
mask_white_bg=False,
|
190 |
+
show_cam=True,
|
191 |
+
mask_sky=False,
|
192 |
+
prediction_mode="Pointmap Regression",
|
193 |
+
):
|
194 |
+
"""
|
195 |
+
Perform reconstruction using the already-created target_dir/images.
|
196 |
+
"""
|
197 |
+
if not os.path.isdir(target_dir) or target_dir == "None":
|
198 |
+
return None, "No valid target directory found. Please upload first.", None, None
|
199 |
+
|
200 |
+
start_time = time.time()
|
201 |
+
gc.collect()
|
202 |
+
torch.cuda.empty_cache()
|
203 |
+
|
204 |
+
# Prepare frame_filter dropdown
|
205 |
+
target_dir_images = os.path.join(target_dir, "images")
|
206 |
+
all_files = sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else []
|
207 |
+
all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
|
208 |
+
frame_filter_choices = ["All"] + all_files
|
209 |
+
|
210 |
+
print("Running run_model...")
|
211 |
+
with torch.no_grad():
|
212 |
+
predictions = run_model(target_dir, model)
|
213 |
+
|
214 |
+
# Save predictions
|
215 |
+
prediction_save_path = os.path.join(target_dir, "predictions.npz")
|
216 |
+
np.savez(prediction_save_path, **predictions)
|
217 |
+
|
218 |
+
# Handle None frame_filter
|
219 |
+
if frame_filter is None:
|
220 |
+
frame_filter = "All"
|
221 |
+
|
222 |
+
# Build a GLB file name
|
223 |
+
glbfile = os.path.join(
|
224 |
+
target_dir,
|
225 |
+
f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb",
|
226 |
+
)
|
227 |
+
|
228 |
+
# Convert predictions to GLB
|
229 |
+
glbscene = predictions_to_glb(
|
230 |
+
predictions,
|
231 |
+
conf_thres=conf_thres,
|
232 |
+
filter_by_frames=frame_filter,
|
233 |
+
mask_black_bg=mask_black_bg,
|
234 |
+
mask_white_bg=mask_white_bg,
|
235 |
+
show_cam=show_cam,
|
236 |
+
mask_sky=mask_sky,
|
237 |
+
target_dir=target_dir,
|
238 |
+
prediction_mode=prediction_mode,
|
239 |
+
)
|
240 |
+
glbscene.export(file_obj=glbfile)
|
241 |
+
|
242 |
+
# Cleanup
|
243 |
+
del predictions
|
244 |
+
gc.collect()
|
245 |
+
torch.cuda.empty_cache()
|
246 |
+
|
247 |
+
end_time = time.time()
|
248 |
+
print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
|
249 |
+
log_msg = f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
|
250 |
+
|
251 |
+
return glbfile, log_msg, gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True)
|
252 |
+
|
253 |
+
|
254 |
+
# -------------------------------------------------------------------------
|
255 |
+
# 5) Helper functions for UI resets + re-visualization
|
256 |
+
# -------------------------------------------------------------------------
|
257 |
+
def clear_fields():
|
258 |
+
"""
|
259 |
+
Clears the 3D viewer, the stored target_dir, and empties the gallery.
|
260 |
+
"""
|
261 |
+
return None
|
262 |
+
|
263 |
+
|
264 |
+
def update_log():
|
265 |
+
"""
|
266 |
+
Display a quick log message while waiting.
|
267 |
+
"""
|
268 |
+
return "Loading and Reconstructing..."
|
269 |
+
|
270 |
+
|
271 |
+
def update_visualization(
|
272 |
+
target_dir, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, is_example
|
273 |
+
):
|
274 |
+
"""
|
275 |
+
Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
|
276 |
+
and return it for the 3D viewer. If is_example == "True", skip.
|
277 |
+
"""
|
278 |
+
|
279 |
+
# If it's an example click, skip as requested
|
280 |
+
if is_example == "True":
|
281 |
+
return None, "No reconstruction available. Please click the Reconstruct button first."
|
282 |
+
|
283 |
+
if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
|
284 |
+
return None, "No reconstruction available. Please click the Reconstruct button first."
|
285 |
+
|
286 |
+
predictions_path = os.path.join(target_dir, "predictions.npz")
|
287 |
+
if not os.path.exists(predictions_path):
|
288 |
+
return None, f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first."
|
289 |
+
|
290 |
+
key_list = [
|
291 |
+
"pose_enc",
|
292 |
+
"depth",
|
293 |
+
"depth_conf",
|
294 |
+
"world_points",
|
295 |
+
"world_points_conf",
|
296 |
+
"images",
|
297 |
+
"extrinsic",
|
298 |
+
"intrinsic",
|
299 |
+
"world_points_from_depth",
|
300 |
+
]
|
301 |
+
|
302 |
+
loaded = np.load(predictions_path)
|
303 |
+
predictions = {key: np.array(loaded[key]) for key in key_list}
|
304 |
+
|
305 |
+
glbfile = os.path.join(
|
306 |
+
target_dir,
|
307 |
+
f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb",
|
308 |
+
)
|
309 |
+
|
310 |
+
if not os.path.exists(glbfile):
|
311 |
+
glbscene = predictions_to_glb(
|
312 |
+
predictions,
|
313 |
+
conf_thres=conf_thres,
|
314 |
+
filter_by_frames=frame_filter,
|
315 |
+
mask_black_bg=mask_black_bg,
|
316 |
+
mask_white_bg=mask_white_bg,
|
317 |
+
show_cam=show_cam,
|
318 |
+
mask_sky=mask_sky,
|
319 |
+
target_dir=target_dir,
|
320 |
+
prediction_mode=prediction_mode,
|
321 |
+
)
|
322 |
+
glbscene.export(file_obj=glbfile)
|
323 |
+
|
324 |
+
return glbfile, "Updating Visualization"
|
325 |
+
|
326 |
+
# -------------------------------------------------------------------------
|
327 |
+
# Example images
|
328 |
+
# -------------------------------------------------------------------------
|
329 |
+
|
330 |
+
|
331 |
+
def get_examples_from_folder(images_folder):
|
332 |
+
"""
|
333 |
+
Create an example using all JPG/JPEG files from the specified folder.
|
334 |
+
No caching, directly uses the images from the folder.
|
335 |
+
"""
|
336 |
+
examples = []
|
337 |
+
|
338 |
+
if not os.path.exists(images_folder):
|
339 |
+
print(f"Warning: Images folder {images_folder} does not exist.")
|
340 |
+
return examples
|
341 |
+
|
342 |
+
image_files = []
|
343 |
+
for ext in ['*.jpg', '*.jpeg', '*.JPG', '*.JPEG', '*.png', '*.PNG']:
|
344 |
+
image_files.extend(glob.glob(os.path.join(images_folder, ext)))
|
345 |
+
|
346 |
+
image_files = sorted(image_files)
|
347 |
+
|
348 |
+
if not image_files:
|
349 |
+
print(f"Warning: No images found in {images_folder}.")
|
350 |
+
return examples
|
351 |
+
|
352 |
+
num_images = len(image_files)
|
353 |
+
print(f"Found {num_images} images in {images_folder}")
|
354 |
+
|
355 |
+
example = [
|
356 |
+
None,
|
357 |
+
str(num_images),
|
358 |
+
image_files,
|
359 |
+
20.0,
|
360 |
+
False,
|
361 |
+
False,
|
362 |
+
True,
|
363 |
+
False,
|
364 |
+
"Depthmap and Camera Branch",
|
365 |
+
"True"
|
366 |
+
]
|
367 |
+
|
368 |
+
examples.append(example)
|
369 |
+
return examples
|
370 |
+
|
371 |
+
building_folder = "example_building/"
|
372 |
+
|
373 |
+
# -------------------------------------------------------------------------
|
374 |
+
# 6) Build Gradio UI
|
375 |
+
# -------------------------------------------------------------------------
|
376 |
+
theme = gr.themes.Ocean()
|
377 |
+
theme.set(
|
378 |
+
checkbox_label_background_fill_selected="*button_primary_background_fill",
|
379 |
+
checkbox_label_text_color_selected="*button_primary_text_color",
|
380 |
+
)
|
381 |
+
|
382 |
+
with gr.Blocks(
|
383 |
+
theme=theme,
|
384 |
+
css="""
|
385 |
+
.custom-log * {
|
386 |
+
font-style: italic;
|
387 |
+
font-size: 22px !important;
|
388 |
+
background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
|
389 |
+
-webkit-background-clip: text;
|
390 |
+
background-clip: text;
|
391 |
+
font-weight: bold !important;
|
392 |
+
color: transparent !important;
|
393 |
+
text-align: center !important;
|
394 |
+
}
|
395 |
+
|
396 |
+
.example-log * {
|
397 |
+
font-style: italic;
|
398 |
+
font-size: 16px !important;
|
399 |
+
background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
|
400 |
+
-webkit-background-clip: text;
|
401 |
+
background-clip: text;
|
402 |
+
color: transparent !important;
|
403 |
+
}
|
404 |
+
|
405 |
+
#my_radio .wrap {
|
406 |
+
display: flex;
|
407 |
+
flex-wrap: nowrap;
|
408 |
+
justify-content: center;
|
409 |
+
align-items: center;
|
410 |
+
}
|
411 |
+
|
412 |
+
#my_radio .wrap label {
|
413 |
+
display: flex;
|
414 |
+
width: 50%;
|
415 |
+
justify-content: center;
|
416 |
+
align-items: center;
|
417 |
+
margin: 0;
|
418 |
+
padding: 10px 0;
|
419 |
+
box-sizing: border-box;
|
420 |
+
}
|
421 |
+
""",
|
422 |
+
) as demo:
|
423 |
+
# Instead of gr.State, we use a hidden Textbox:
|
424 |
+
is_example = gr.Textbox(label="is_example", visible=False, value="None")
|
425 |
+
num_images = gr.Textbox(label="num_images", visible=False, value="None")
|
426 |
+
|
427 |
+
gr.HTML(
|
428 |
+
"""
|
429 |
+
<h1>🏛️ VGGT: Visual Geometry Grounded Transformer</h1>
|
430 |
+
<p>
|
431 |
+
<a href="https://github.com/facebookresearch/vggt">🐙 GitHub Repository</a> |
|
432 |
+
<a href="#">Project Page</a>
|
433 |
+
</p>
|
434 |
+
|
435 |
+
<div style="font-size: 16px; line-height: 1.5;">
|
436 |
+
<p>Upload a video or a set of images to create a 3D reconstruction of a scene or object. VGGT takes these images and generates a 3D point cloud, along with estimated camera poses.</p>
|
437 |
+
|
438 |
+
<h3>Getting Started:</h3>
|
439 |
+
<ol>
|
440 |
+
<li><strong>Upload Your Data:</strong> Use the "Upload Video" or "Upload Images" buttons on the left to provide your input. Videos will be automatically split into individual frames (one frame per second).</li>
|
441 |
+
<li><strong>Preview:</strong> Your uploaded images will appear in the gallery on the left.</li>
|
442 |
+
<li><strong>Reconstruct:</strong> Click the "Reconstruct" button to start the 3D reconstruction process.</li>
|
443 |
+
<li><strong>Visualize:</strong> The 3D reconstruction will appear in the viewer on the right. You can rotate, pan, and zoom to explore the model, and download the GLB file. Note the visualization of 3D points may be slow for a large number of input images.</li>
|
444 |
+
<li>
|
445 |
+
<strong>Adjust Visualization (Optional):</strong>
|
446 |
+
After reconstruction, you can fine-tune the visualization using the options below
|
447 |
+
<details style="display:inline;">
|
448 |
+
<summary style="display:inline;">(<strong>click to expand</strong>):</summary>
|
449 |
+
<ul>
|
450 |
+
<li><em>Confidence Threshold:</em> Adjust the filtering of points based on confidence.</li>
|
451 |
+
<li><em>Show Points from Frame:</em> Select specific frames to display in the point cloud.</li>
|
452 |
+
<li><em>Show Camera:</em> Toggle the display of estimated camera positions.</li>
|
453 |
+
<li><em>Filter Sky / Filter Black Background:</em> Remove sky or black-background points.</li>
|
454 |
+
<li><em>Select a Prediction Mode:</em> Choose between "Depthmap and Camera Branch" or "Pointmap Branch."</li>
|
455 |
+
</ul>
|
456 |
+
</details>
|
457 |
+
</li>
|
458 |
+
</ol>
|
459 |
+
<p><strong style="color: #0ea5e9;">Please note:</strong> <span style="color: #0ea5e9; font-weight: bold;">VGGT typically reconstructs a scene in less than 1 second. However, visualizing 3D points may take tens of seconds due to third-party rendering, which are independent of VGGT's processing time. </span></p>
|
460 |
+
</div>
|
461 |
+
"""
|
462 |
+
)
|
463 |
+
|
464 |
+
target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
|
465 |
+
|
466 |
+
with gr.Row():
|
467 |
+
with gr.Column(scale=2):
|
468 |
+
input_video = gr.Video(label="Upload Video", interactive=True)
|
469 |
+
input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)
|
470 |
+
|
471 |
+
image_gallery = gr.Gallery(
|
472 |
+
label="Preview",
|
473 |
+
columns=4,
|
474 |
+
height="300px",
|
475 |
+
show_download_button=True,
|
476 |
+
object_fit="contain",
|
477 |
+
preview=True,
|
478 |
+
)
|
479 |
+
|
480 |
+
with gr.Column(scale=4):
|
481 |
+
with gr.Column():
|
482 |
+
gr.Markdown("**3D Reconstruction (Point Cloud and Camera Poses)**")
|
483 |
+
log_output = gr.Markdown(
|
484 |
+
"Please upload a video or images, then click Reconstruct.", elem_classes=["custom-log"]
|
485 |
+
)
|
486 |
+
reconstruction_output = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5)
|
487 |
+
|
488 |
+
with gr.Row():
|
489 |
+
submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
|
490 |
+
clear_btn = gr.ClearButton(
|
491 |
+
[input_video, input_images, reconstruction_output, log_output, target_dir_output, image_gallery],
|
492 |
+
scale=1,
|
493 |
+
)
|
494 |
+
|
495 |
+
with gr.Row():
|
496 |
+
prediction_mode = gr.Radio(
|
497 |
+
["Depthmap and Camera Branch", "Pointmap Branch"],
|
498 |
+
label="Select a Prediction Mode",
|
499 |
+
value="Depthmap and Camera Branch",
|
500 |
+
scale=1,
|
501 |
+
elem_id="my_radio",
|
502 |
+
)
|
503 |
+
|
504 |
+
with gr.Row():
|
505 |
+
conf_thres = gr.Slider(minimum=0, maximum=100, value=50, step=0.1, label="Confidence Threshold (%)")
|
506 |
+
frame_filter = gr.Dropdown(choices=["All"], value="All", label="Show Points from Frame")
|
507 |
+
with gr.Column():
|
508 |
+
show_cam = gr.Checkbox(label="Show Camera", value=True)
|
509 |
+
mask_sky = gr.Checkbox(label="Filter Sky", value=False)
|
510 |
+
mask_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
|
511 |
+
mask_white_bg = gr.Checkbox(label="Filter White Background", value=False)
|
512 |
+
|
513 |
+
# ---------------------- Examples section ----------------------
|
514 |
+
examples = get_examples_from_folder(building_folder)
|
515 |
+
|
516 |
+
def example_pipeline(
|
517 |
+
input_video,
|
518 |
+
num_images_str,
|
519 |
+
input_images,
|
520 |
+
conf_thres,
|
521 |
+
mask_black_bg,
|
522 |
+
mask_white_bg,
|
523 |
+
show_cam,
|
524 |
+
mask_sky,
|
525 |
+
prediction_mode,
|
526 |
+
is_example_str,
|
527 |
+
):
|
528 |
+
"""
|
529 |
+
1) Copy example images to new target_dir
|
530 |
+
2) Reconstruct
|
531 |
+
3) Return model3D + logs + new_dir + updated dropdown + gallery
|
532 |
+
We do NOT return is_example. It's just an input.
|
533 |
+
"""
|
534 |
+
target_dir, image_paths = handle_uploads(input_video, input_images)
|
535 |
+
# Always use "All" for frame_filter in examples
|
536 |
+
frame_filter = "All"
|
537 |
+
glbfile, log_msg, dropdown = gradio_demo(
|
538 |
+
target_dir, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode
|
539 |
+
)
|
540 |
+
return glbfile, log_msg, target_dir, dropdown, image_paths
|
541 |
+
|
542 |
+
gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])
|
543 |
+
|
544 |
+
gr.Examples(
|
545 |
+
examples=examples,
|
546 |
+
inputs=[
|
547 |
+
input_video,
|
548 |
+
num_images,
|
549 |
+
input_images,
|
550 |
+
conf_thres,
|
551 |
+
mask_black_bg,
|
552 |
+
mask_white_bg,
|
553 |
+
show_cam,
|
554 |
+
mask_sky,
|
555 |
+
prediction_mode,
|
556 |
+
is_example,
|
557 |
+
],
|
558 |
+
outputs=[reconstruction_output, log_output, target_dir_output, frame_filter, image_gallery],
|
559 |
+
fn=example_pipeline,
|
560 |
+
cache_examples=False,
|
561 |
+
examples_per_page=50,
|
562 |
+
)
|
563 |
+
|
564 |
+
# -------------------------------------------------------------------------
|
565 |
+
# "Reconstruct" button logic:
|
566 |
+
# - Clear fields
|
567 |
+
# - Update log
|
568 |
+
# - gradio_demo(...) with the existing target_dir
|
569 |
+
# - Then set is_example = "False"
|
570 |
+
# -------------------------------------------------------------------------
|
571 |
+
submit_btn.click(fn=clear_fields, inputs=[], outputs=[reconstruction_output]).then(
|
572 |
+
fn=update_log, inputs=[], outputs=[log_output]
|
573 |
+
).then(
|
574 |
+
fn=gradio_demo,
|
575 |
+
inputs=[
|
576 |
+
target_dir_output,
|
577 |
+
conf_thres,
|
578 |
+
frame_filter,
|
579 |
+
mask_black_bg,
|
580 |
+
mask_white_bg,
|
581 |
+
show_cam,
|
582 |
+
mask_sky,
|
583 |
+
prediction_mode,
|
584 |
+
],
|
585 |
+
outputs=[reconstruction_output, log_output, frame_filter],
|
586 |
+
).then(
|
587 |
+
fn=lambda: "False", inputs=[], outputs=[is_example] # set is_example to "False"
|
588 |
+
)
|
589 |
+
|
590 |
+
# -------------------------------------------------------------------------
|
591 |
+
# Real-time Visualization Updates
|
592 |
+
# -------------------------------------------------------------------------
|
593 |
+
conf_thres.change(
|
594 |
+
update_visualization,
|
595 |
+
[
|
596 |
+
target_dir_output,
|
597 |
+
conf_thres,
|
598 |
+
frame_filter,
|
599 |
+
mask_black_bg,
|
600 |
+
mask_white_bg,
|
601 |
+
show_cam,
|
602 |
+
mask_sky,
|
603 |
+
prediction_mode,
|
604 |
+
is_example,
|
605 |
+
],
|
606 |
+
[reconstruction_output, log_output],
|
607 |
+
)
|
608 |
+
frame_filter.change(
|
609 |
+
update_visualization,
|
610 |
+
[
|
611 |
+
target_dir_output,
|
612 |
+
conf_thres,
|
613 |
+
frame_filter,
|
614 |
+
mask_black_bg,
|
615 |
+
mask_white_bg,
|
616 |
+
show_cam,
|
617 |
+
mask_sky,
|
618 |
+
prediction_mode,
|
619 |
+
is_example,
|
620 |
+
],
|
621 |
+
[reconstruction_output, log_output],
|
622 |
+
)
|
623 |
+
mask_black_bg.change(
|
624 |
+
update_visualization,
|
625 |
+
[
|
626 |
+
target_dir_output,
|
627 |
+
conf_thres,
|
628 |
+
frame_filter,
|
629 |
+
mask_black_bg,
|
630 |
+
mask_white_bg,
|
631 |
+
show_cam,
|
632 |
+
mask_sky,
|
633 |
+
prediction_mode,
|
634 |
+
is_example,
|
635 |
+
],
|
636 |
+
[reconstruction_output, log_output],
|
637 |
+
)
|
638 |
+
mask_white_bg.change(
|
639 |
+
update_visualization,
|
640 |
+
[
|
641 |
+
target_dir_output,
|
642 |
+
conf_thres,
|
643 |
+
frame_filter,
|
644 |
+
mask_black_bg,
|
645 |
+
mask_white_bg,
|
646 |
+
show_cam,
|
647 |
+
mask_sky,
|
648 |
+
prediction_mode,
|
649 |
+
is_example,
|
650 |
+
],
|
651 |
+
[reconstruction_output, log_output],
|
652 |
+
)
|
653 |
+
show_cam.change(
|
654 |
+
update_visualization,
|
655 |
+
[
|
656 |
+
target_dir_output,
|
657 |
+
conf_thres,
|
658 |
+
frame_filter,
|
659 |
+
mask_black_bg,
|
660 |
+
mask_white_bg,
|
661 |
+
show_cam,
|
662 |
+
mask_sky,
|
663 |
+
prediction_mode,
|
664 |
+
is_example,
|
665 |
+
],
|
666 |
+
[reconstruction_output, log_output],
|
667 |
+
)
|
668 |
+
mask_sky.change(
|
669 |
+
update_visualization,
|
670 |
+
[
|
671 |
+
target_dir_output,
|
672 |
+
conf_thres,
|
673 |
+
frame_filter,
|
674 |
+
mask_black_bg,
|
675 |
+
mask_white_bg,
|
676 |
+
show_cam,
|
677 |
+
mask_sky,
|
678 |
+
prediction_mode,
|
679 |
+
is_example,
|
680 |
+
],
|
681 |
+
[reconstruction_output, log_output],
|
682 |
+
)
|
683 |
+
prediction_mode.change(
|
684 |
+
update_visualization,
|
685 |
+
[
|
686 |
+
target_dir_output,
|
687 |
+
conf_thres,
|
688 |
+
frame_filter,
|
689 |
+
mask_black_bg,
|
690 |
+
mask_white_bg,
|
691 |
+
show_cam,
|
692 |
+
mask_sky,
|
693 |
+
prediction_mode,
|
694 |
+
is_example,
|
695 |
+
],
|
696 |
+
[reconstruction_output, log_output],
|
697 |
+
)
|
698 |
|
699 |
+
# -------------------------------------------------------------------------
|
700 |
+
# Auto-update gallery whenever user uploads or changes their files
|
701 |
+
# -------------------------------------------------------------------------
|
702 |
+
input_video.change(
|
703 |
+
fn=update_gallery_on_upload,
|
704 |
+
inputs=[input_video, input_images],
|
705 |
+
outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
|
706 |
+
)
|
707 |
+
input_images.change(
|
708 |
+
fn=update_gallery_on_upload,
|
709 |
+
inputs=[input_video, input_images],
|
710 |
+
outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
|
711 |
+
)
|
712 |
|
713 |
+
demo.queue(max_size=20).launch(show_error=True, share=True)
|
|
example_building/0.jpg
ADDED
![]() |
example_building/1.jpg
ADDED
![]() |
example_building/2.jpg
ADDED
![]() |
example_building/3.jpg
ADDED
![]() |
example_building/4.jpg
ADDED
![]() |
example_building/5.jpg
ADDED
![]() |
example_building/6.jpg
ADDED
![]() |
example_building/7.jpg
ADDED
![]() |
example_building/8.jpg
ADDED
![]() |
requirements.txt
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.3.1
|
2 |
+
torchvision==0.18.1
|
3 |
+
numpy==1.26.1
|
4 |
+
Pillow
|
5 |
+
huggingface_hub
|
6 |
+
einops
|
7 |
+
safetensors
|
8 |
+
gradio
|
9 |
+
viser==0.2.23
|
10 |
+
tqdm
|
11 |
+
hydra-core
|
12 |
+
omegaconf
|
13 |
+
opencv-python
|
14 |
+
scipy
|
15 |
+
onnxruntime
|
16 |
+
requests
|
17 |
+
trimesh
|
18 |
+
matplotlib
|
19 |
+
gradio_client
|
vggt/heads/__pycache__/camera_head.cpython-310.pyc
ADDED
Binary file (4.27 kB). View file
|
|
vggt/heads/__pycache__/camera_head.cpython-311.pyc
ADDED
Binary file (6.8 kB). View file
|
|
vggt/heads/__pycache__/camera_head.cpython-312.pyc
ADDED
Binary file (6.13 kB). View file
|
|
vggt/heads/__pycache__/dpt_head.cpython-310.pyc
ADDED
Binary file (12.6 kB). View file
|
|
vggt/heads/__pycache__/dpt_head.cpython-311.pyc
ADDED
Binary file (21.7 kB). View file
|
|
vggt/heads/__pycache__/dpt_head.cpython-312.pyc
ADDED
Binary file (20.3 kB). View file
|
|
vggt/heads/__pycache__/head_act.cpython-310.pyc
ADDED
Binary file (3.1 kB). View file
|
|
vggt/heads/__pycache__/head_act.cpython-311.pyc
ADDED
Binary file (4.87 kB). View file
|
|
vggt/heads/__pycache__/head_act.cpython-312.pyc
ADDED
Binary file (4.49 kB). View file
|
|
vggt/heads/__pycache__/track_head.cpython-310.pyc
ADDED
Binary file (3.42 kB). View file
|
|
vggt/heads/__pycache__/track_head.cpython-311.pyc
ADDED
Binary file (4.09 kB). View file
|
|
vggt/heads/__pycache__/track_head.cpython-312.pyc
ADDED
Binary file (3.81 kB). View file
|
|
vggt/heads/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (3.13 kB). View file
|
|
vggt/heads/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (4.62 kB). View file
|
|
vggt/heads/__pycache__/utils.cpython-312.pyc
ADDED
Binary file (4.51 kB). View file
|
|
vggt/heads/camera_head.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import math
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
|
14 |
+
from vggt.layers import Mlp
|
15 |
+
from vggt.layers.block import Block
|
16 |
+
from vggt.heads.head_act import activate_pose
|
17 |
+
|
18 |
+
|
19 |
+
class CameraHead(nn.Module):
|
20 |
+
"""
|
21 |
+
CameraHead predicts camera parameters from token representations using iterative refinement.
|
22 |
+
|
23 |
+
It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
dim_in: int = 2048,
|
29 |
+
trunk_depth: int = 4,
|
30 |
+
pose_encoding_type: str = "absT_quaR_FoV",
|
31 |
+
num_heads: int = 16,
|
32 |
+
mlp_ratio: int = 4,
|
33 |
+
init_values: float = 0.01,
|
34 |
+
trans_act: str = "linear",
|
35 |
+
quat_act: str = "linear",
|
36 |
+
fl_act: str = "relu", # Field of view activations: ensures FOV values are positive.
|
37 |
+
):
|
38 |
+
super().__init__()
|
39 |
+
|
40 |
+
if pose_encoding_type == "absT_quaR_FoV":
|
41 |
+
self.target_dim = 9
|
42 |
+
else:
|
43 |
+
raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
|
44 |
+
|
45 |
+
self.trans_act = trans_act
|
46 |
+
self.quat_act = quat_act
|
47 |
+
self.fl_act = fl_act
|
48 |
+
self.trunk_depth = trunk_depth
|
49 |
+
|
50 |
+
# Build the trunk using a sequence of transformer blocks.
|
51 |
+
self.trunk = nn.Sequential(
|
52 |
+
*[
|
53 |
+
Block(
|
54 |
+
dim=dim_in,
|
55 |
+
num_heads=num_heads,
|
56 |
+
mlp_ratio=mlp_ratio,
|
57 |
+
init_values=init_values,
|
58 |
+
)
|
59 |
+
for _ in range(trunk_depth)
|
60 |
+
]
|
61 |
+
)
|
62 |
+
|
63 |
+
# Normalizations for camera token and trunk output.
|
64 |
+
self.token_norm = nn.LayerNorm(dim_in)
|
65 |
+
self.trunk_norm = nn.LayerNorm(dim_in)
|
66 |
+
|
67 |
+
# Learnable empty camera pose token.
|
68 |
+
self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
|
69 |
+
self.embed_pose = nn.Linear(self.target_dim, dim_in)
|
70 |
+
|
71 |
+
# Module for producing modulation parameters: shift, scale, and a gate.
|
72 |
+
self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
|
73 |
+
|
74 |
+
# Adaptive layer normalization without affine parameters.
|
75 |
+
self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
|
76 |
+
self.pose_branch = Mlp(
|
77 |
+
in_features=dim_in,
|
78 |
+
hidden_features=dim_in // 2,
|
79 |
+
out_features=self.target_dim,
|
80 |
+
drop=0,
|
81 |
+
)
|
82 |
+
|
83 |
+
def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
|
84 |
+
"""
|
85 |
+
Forward pass to predict camera parameters.
|
86 |
+
|
87 |
+
Args:
|
88 |
+
aggregated_tokens_list (list): List of token tensors from the network;
|
89 |
+
the last tensor is used for prediction.
|
90 |
+
num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
|
91 |
+
|
92 |
+
Returns:
|
93 |
+
list: A list of predicted camera encodings (post-activation) from each iteration.
|
94 |
+
"""
|
95 |
+
# Use tokens from the last block for camera prediction.
|
96 |
+
tokens = aggregated_tokens_list[-1]
|
97 |
+
|
98 |
+
# Extract the camera tokens
|
99 |
+
pose_tokens = tokens[:, :, 0]
|
100 |
+
pose_tokens = self.token_norm(pose_tokens)
|
101 |
+
|
102 |
+
pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
|
103 |
+
return pred_pose_enc_list
|
104 |
+
|
105 |
+
def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
|
106 |
+
"""
|
107 |
+
Iteratively refine camera pose predictions.
|
108 |
+
|
109 |
+
Args:
|
110 |
+
pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
|
111 |
+
num_iterations (int): Number of refinement iterations.
|
112 |
+
|
113 |
+
Returns:
|
114 |
+
list: List of activated camera encodings from each iteration.
|
115 |
+
"""
|
116 |
+
B, S, C = pose_tokens.shape # S is expected to be 1.
|
117 |
+
pred_pose_enc = None
|
118 |
+
pred_pose_enc_list = []
|
119 |
+
|
120 |
+
for _ in range(num_iterations):
|
121 |
+
# Use a learned empty pose for the first iteration.
|
122 |
+
if pred_pose_enc is None:
|
123 |
+
module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
|
124 |
+
else:
|
125 |
+
# Detach the previous prediction to avoid backprop through time.
|
126 |
+
pred_pose_enc = pred_pose_enc.detach()
|
127 |
+
module_input = self.embed_pose(pred_pose_enc)
|
128 |
+
|
129 |
+
# Generate modulation parameters and split them into shift, scale, and gate components.
|
130 |
+
shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
|
131 |
+
|
132 |
+
# Adaptive layer normalization and modulation.
|
133 |
+
pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
|
134 |
+
pose_tokens_modulated = pose_tokens_modulated + pose_tokens
|
135 |
+
|
136 |
+
pose_tokens_modulated = self.trunk(pose_tokens_modulated)
|
137 |
+
# Compute the delta update for the pose encoding.
|
138 |
+
pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
|
139 |
+
|
140 |
+
if pred_pose_enc is None:
|
141 |
+
pred_pose_enc = pred_pose_enc_delta
|
142 |
+
else:
|
143 |
+
pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
|
144 |
+
|
145 |
+
# Apply final activation functions for translation, quaternion, and field-of-view.
|
146 |
+
activated_pose = activate_pose(
|
147 |
+
pred_pose_enc,
|
148 |
+
trans_act=self.trans_act,
|
149 |
+
quat_act=self.quat_act,
|
150 |
+
fl_act=self.fl_act,
|
151 |
+
)
|
152 |
+
pred_pose_enc_list.append(activated_pose)
|
153 |
+
|
154 |
+
return pred_pose_enc_list
|
155 |
+
|
156 |
+
|
157 |
+
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
158 |
+
"""
|
159 |
+
Modulate the input tensor using scaling and shifting parameters.
|
160 |
+
"""
|
161 |
+
# modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
|
162 |
+
return x * (1 + scale) + shift
|
vggt/heads/dpt_head.py
ADDED
@@ -0,0 +1,497 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
|
8 |
+
# Inspired by https://github.com/DepthAnything/Depth-Anything-V2
|
9 |
+
|
10 |
+
|
11 |
+
import os
|
12 |
+
from typing import List, Dict, Tuple, Union
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from .head_act import activate_head
|
18 |
+
from .utils import create_uv_grid, position_grid_to_embed
|
19 |
+
|
20 |
+
|
21 |
+
class DPTHead(nn.Module):
|
22 |
+
"""
|
23 |
+
DPT Head for dense prediction tasks.
|
24 |
+
|
25 |
+
This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
|
26 |
+
(https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
|
27 |
+
backbone and produces dense predictions by fusing multi-scale features.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
dim_in (int): Input dimension (channels).
|
31 |
+
patch_size (int, optional): Patch size. Default is 14.
|
32 |
+
output_dim (int, optional): Number of output channels. Default is 4.
|
33 |
+
activation (str, optional): Activation type. Default is "inv_log".
|
34 |
+
conf_activation (str, optional): Confidence activation type. Default is "expp1".
|
35 |
+
features (int, optional): Feature channels for intermediate representations. Default is 256.
|
36 |
+
out_channels (List[int], optional): Output channels for each intermediate layer.
|
37 |
+
intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
|
38 |
+
pos_embed (bool, optional): Whether to use positional embedding. Default is True.
|
39 |
+
feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
|
40 |
+
down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
|
41 |
+
"""
|
42 |
+
|
43 |
+
def __init__(
|
44 |
+
self,
|
45 |
+
dim_in: int,
|
46 |
+
patch_size: int = 14,
|
47 |
+
output_dim: int = 4,
|
48 |
+
activation: str = "inv_log",
|
49 |
+
conf_activation: str = "expp1",
|
50 |
+
features: int = 256,
|
51 |
+
out_channels: List[int] = [256, 512, 1024, 1024],
|
52 |
+
intermediate_layer_idx: List[int] = [4, 11, 17, 23],
|
53 |
+
pos_embed: bool = True,
|
54 |
+
feature_only: bool = False,
|
55 |
+
down_ratio: int = 1,
|
56 |
+
) -> None:
|
57 |
+
super(DPTHead, self).__init__()
|
58 |
+
self.patch_size = patch_size
|
59 |
+
self.activation = activation
|
60 |
+
self.conf_activation = conf_activation
|
61 |
+
self.pos_embed = pos_embed
|
62 |
+
self.feature_only = feature_only
|
63 |
+
self.down_ratio = down_ratio
|
64 |
+
self.intermediate_layer_idx = intermediate_layer_idx
|
65 |
+
|
66 |
+
self.norm = nn.LayerNorm(dim_in)
|
67 |
+
|
68 |
+
# Projection layers for each output channel from tokens.
|
69 |
+
self.projects = nn.ModuleList(
|
70 |
+
[
|
71 |
+
nn.Conv2d(
|
72 |
+
in_channels=dim_in,
|
73 |
+
out_channels=oc,
|
74 |
+
kernel_size=1,
|
75 |
+
stride=1,
|
76 |
+
padding=0,
|
77 |
+
)
|
78 |
+
for oc in out_channels
|
79 |
+
]
|
80 |
+
)
|
81 |
+
|
82 |
+
# Resize layers for upsampling feature maps.
|
83 |
+
self.resize_layers = nn.ModuleList(
|
84 |
+
[
|
85 |
+
nn.ConvTranspose2d(
|
86 |
+
in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
|
87 |
+
),
|
88 |
+
nn.ConvTranspose2d(
|
89 |
+
in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
|
90 |
+
),
|
91 |
+
nn.Identity(),
|
92 |
+
nn.Conv2d(
|
93 |
+
in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
|
94 |
+
),
|
95 |
+
]
|
96 |
+
)
|
97 |
+
|
98 |
+
self.scratch = _make_scratch(
|
99 |
+
out_channels,
|
100 |
+
features,
|
101 |
+
expand=False,
|
102 |
+
)
|
103 |
+
|
104 |
+
# Attach additional modules to scratch.
|
105 |
+
self.scratch.stem_transpose = None
|
106 |
+
self.scratch.refinenet1 = _make_fusion_block(features)
|
107 |
+
self.scratch.refinenet2 = _make_fusion_block(features)
|
108 |
+
self.scratch.refinenet3 = _make_fusion_block(features)
|
109 |
+
self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
|
110 |
+
|
111 |
+
head_features_1 = features
|
112 |
+
head_features_2 = 32
|
113 |
+
|
114 |
+
if feature_only:
|
115 |
+
self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
|
116 |
+
else:
|
117 |
+
self.scratch.output_conv1 = nn.Conv2d(
|
118 |
+
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
|
119 |
+
)
|
120 |
+
conv2_in_channels = head_features_1 // 2
|
121 |
+
|
122 |
+
self.scratch.output_conv2 = nn.Sequential(
|
123 |
+
nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
|
124 |
+
nn.ReLU(inplace=True),
|
125 |
+
nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
|
126 |
+
)
|
127 |
+
|
128 |
+
def forward(
|
129 |
+
self,
|
130 |
+
aggregated_tokens_list: List[torch.Tensor],
|
131 |
+
images: torch.Tensor,
|
132 |
+
patch_start_idx: int,
|
133 |
+
frames_chunk_size: int = 8,
|
134 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
135 |
+
"""
|
136 |
+
Forward pass through the DPT head, supports processing by chunking frames.
|
137 |
+
Args:
|
138 |
+
aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
|
139 |
+
images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
|
140 |
+
patch_start_idx (int): Starting index for patch tokens in the token sequence.
|
141 |
+
Used to separate patch tokens from other tokens (e.g., camera or register tokens).
|
142 |
+
frames_chunk_size (int, optional): Number of frames to process in each chunk.
|
143 |
+
If None or larger than S, all frames are processed at once. Default: 8.
|
144 |
+
|
145 |
+
Returns:
|
146 |
+
Tensor or Tuple[Tensor, Tensor]:
|
147 |
+
- If feature_only=True: Feature maps with shape [B, S, C, H, W]
|
148 |
+
- Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
|
149 |
+
"""
|
150 |
+
B, S, _, H, W = images.shape
|
151 |
+
|
152 |
+
# If frames_chunk_size is not specified or greater than S, process all frames at once
|
153 |
+
if frames_chunk_size is None or frames_chunk_size >= S:
|
154 |
+
return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)
|
155 |
+
|
156 |
+
# Otherwise, process frames in chunks to manage memory usage
|
157 |
+
assert frames_chunk_size > 0
|
158 |
+
|
159 |
+
# Process frames in batches
|
160 |
+
all_preds = []
|
161 |
+
all_conf = []
|
162 |
+
|
163 |
+
for frames_start_idx in range(0, S, frames_chunk_size):
|
164 |
+
frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
|
165 |
+
|
166 |
+
# Process batch of frames
|
167 |
+
if self.feature_only:
|
168 |
+
chunk_output = self._forward_impl(
|
169 |
+
aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
|
170 |
+
)
|
171 |
+
all_preds.append(chunk_output)
|
172 |
+
else:
|
173 |
+
chunk_preds, chunk_conf = self._forward_impl(
|
174 |
+
aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
|
175 |
+
)
|
176 |
+
all_preds.append(chunk_preds)
|
177 |
+
all_conf.append(chunk_conf)
|
178 |
+
|
179 |
+
# Concatenate results along the sequence dimension
|
180 |
+
if self.feature_only:
|
181 |
+
return torch.cat(all_preds, dim=1)
|
182 |
+
else:
|
183 |
+
return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)
|
184 |
+
|
185 |
+
def _forward_impl(
|
186 |
+
self,
|
187 |
+
aggregated_tokens_list: List[torch.Tensor],
|
188 |
+
images: torch.Tensor,
|
189 |
+
patch_start_idx: int,
|
190 |
+
frames_start_idx: int = None,
|
191 |
+
frames_end_idx: int = None,
|
192 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
193 |
+
"""
|
194 |
+
Implementation of the forward pass through the DPT head.
|
195 |
+
|
196 |
+
This method processes a specific chunk of frames from the sequence.
|
197 |
+
|
198 |
+
Args:
|
199 |
+
aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
|
200 |
+
images (Tensor): Input images with shape [B, S, 3, H, W].
|
201 |
+
patch_start_idx (int): Starting index for patch tokens.
|
202 |
+
frames_start_idx (int, optional): Starting index for frames to process.
|
203 |
+
frames_end_idx (int, optional): Ending index for frames to process.
|
204 |
+
|
205 |
+
Returns:
|
206 |
+
Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
|
207 |
+
"""
|
208 |
+
if frames_start_idx is not None and frames_end_idx is not None:
|
209 |
+
images = images[:, frames_start_idx:frames_end_idx].contiguous()
|
210 |
+
|
211 |
+
B, S, _, H, W = images.shape
|
212 |
+
|
213 |
+
patch_h, patch_w = H // self.patch_size, W // self.patch_size
|
214 |
+
|
215 |
+
out = []
|
216 |
+
dpt_idx = 0
|
217 |
+
|
218 |
+
for layer_idx in self.intermediate_layer_idx:
|
219 |
+
x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
|
220 |
+
|
221 |
+
# Select frames if processing a chunk
|
222 |
+
if frames_start_idx is not None and frames_end_idx is not None:
|
223 |
+
x = x[:, frames_start_idx:frames_end_idx]
|
224 |
+
|
225 |
+
x = x.reshape(B * S, -1, x.shape[-1])
|
226 |
+
|
227 |
+
x = self.norm(x)
|
228 |
+
|
229 |
+
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
|
230 |
+
|
231 |
+
x = self.projects[dpt_idx](x)
|
232 |
+
if self.pos_embed:
|
233 |
+
x = self._apply_pos_embed(x, W, H)
|
234 |
+
x = self.resize_layers[dpt_idx](x)
|
235 |
+
|
236 |
+
out.append(x)
|
237 |
+
dpt_idx += 1
|
238 |
+
|
239 |
+
# Fuse features from multiple layers.
|
240 |
+
out = self.scratch_forward(out)
|
241 |
+
# Interpolate fused output to match target image resolution.
|
242 |
+
out = custom_interpolate(
|
243 |
+
out,
|
244 |
+
(int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
|
245 |
+
mode="bilinear",
|
246 |
+
align_corners=True,
|
247 |
+
)
|
248 |
+
|
249 |
+
if self.pos_embed:
|
250 |
+
out = self._apply_pos_embed(out, W, H)
|
251 |
+
|
252 |
+
if self.feature_only:
|
253 |
+
return out.reshape(B, S, *out.shape[1:])
|
254 |
+
|
255 |
+
out = self.scratch.output_conv2(out)
|
256 |
+
preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)
|
257 |
+
|
258 |
+
preds = preds.reshape(B, S, *preds.shape[1:])
|
259 |
+
conf = conf.reshape(B, S, *conf.shape[1:])
|
260 |
+
return preds, conf
|
261 |
+
|
262 |
+
def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
|
263 |
+
"""
|
264 |
+
Apply positional embedding to tensor x.
|
265 |
+
"""
|
266 |
+
patch_w = x.shape[-1]
|
267 |
+
patch_h = x.shape[-2]
|
268 |
+
pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
|
269 |
+
pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
|
270 |
+
pos_embed = pos_embed * ratio
|
271 |
+
pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
|
272 |
+
return x + pos_embed
|
273 |
+
|
274 |
+
def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
|
275 |
+
"""
|
276 |
+
Forward pass through the fusion blocks.
|
277 |
+
|
278 |
+
Args:
|
279 |
+
features (List[Tensor]): List of feature maps from different layers.
|
280 |
+
|
281 |
+
Returns:
|
282 |
+
Tensor: Fused feature map.
|
283 |
+
"""
|
284 |
+
layer_1, layer_2, layer_3, layer_4 = features
|
285 |
+
|
286 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
287 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
288 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
289 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
290 |
+
|
291 |
+
out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
|
292 |
+
del layer_4_rn, layer_4
|
293 |
+
|
294 |
+
out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
|
295 |
+
del layer_3_rn, layer_3
|
296 |
+
|
297 |
+
out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
|
298 |
+
del layer_2_rn, layer_2
|
299 |
+
|
300 |
+
out = self.scratch.refinenet1(out, layer_1_rn)
|
301 |
+
del layer_1_rn, layer_1
|
302 |
+
|
303 |
+
out = self.scratch.output_conv1(out)
|
304 |
+
return out
|
305 |
+
|
306 |
+
|
307 |
+
################################################################################
|
308 |
+
# Modules
|
309 |
+
################################################################################
|
310 |
+
|
311 |
+
|
312 |
+
def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
|
313 |
+
return FeatureFusionBlock(
|
314 |
+
features,
|
315 |
+
nn.ReLU(inplace=True),
|
316 |
+
deconv=False,
|
317 |
+
bn=False,
|
318 |
+
expand=False,
|
319 |
+
align_corners=True,
|
320 |
+
size=size,
|
321 |
+
has_residual=has_residual,
|
322 |
+
groups=groups,
|
323 |
+
)
|
324 |
+
|
325 |
+
|
326 |
+
def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
|
327 |
+
scratch = nn.Module()
|
328 |
+
out_shape1 = out_shape
|
329 |
+
out_shape2 = out_shape
|
330 |
+
out_shape3 = out_shape
|
331 |
+
if len(in_shape) >= 4:
|
332 |
+
out_shape4 = out_shape
|
333 |
+
|
334 |
+
if expand:
|
335 |
+
out_shape1 = out_shape
|
336 |
+
out_shape2 = out_shape * 2
|
337 |
+
out_shape3 = out_shape * 4
|
338 |
+
if len(in_shape) >= 4:
|
339 |
+
out_shape4 = out_shape * 8
|
340 |
+
|
341 |
+
scratch.layer1_rn = nn.Conv2d(
|
342 |
+
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
343 |
+
)
|
344 |
+
scratch.layer2_rn = nn.Conv2d(
|
345 |
+
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
346 |
+
)
|
347 |
+
scratch.layer3_rn = nn.Conv2d(
|
348 |
+
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
349 |
+
)
|
350 |
+
if len(in_shape) >= 4:
|
351 |
+
scratch.layer4_rn = nn.Conv2d(
|
352 |
+
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
353 |
+
)
|
354 |
+
return scratch
|
355 |
+
|
356 |
+
|
357 |
+
class ResidualConvUnit(nn.Module):
|
358 |
+
"""Residual convolution module."""
|
359 |
+
|
360 |
+
def __init__(self, features, activation, bn, groups=1):
|
361 |
+
"""Init.
|
362 |
+
|
363 |
+
Args:
|
364 |
+
features (int): number of features
|
365 |
+
"""
|
366 |
+
super().__init__()
|
367 |
+
|
368 |
+
self.bn = bn
|
369 |
+
self.groups = groups
|
370 |
+
self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
371 |
+
self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
372 |
+
|
373 |
+
self.norm1 = None
|
374 |
+
self.norm2 = None
|
375 |
+
|
376 |
+
self.activation = activation
|
377 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
378 |
+
|
379 |
+
def forward(self, x):
|
380 |
+
"""Forward pass.
|
381 |
+
|
382 |
+
Args:
|
383 |
+
x (tensor): input
|
384 |
+
|
385 |
+
Returns:
|
386 |
+
tensor: output
|
387 |
+
"""
|
388 |
+
|
389 |
+
out = self.activation(x)
|
390 |
+
out = self.conv1(out)
|
391 |
+
if self.norm1 is not None:
|
392 |
+
out = self.norm1(out)
|
393 |
+
|
394 |
+
out = self.activation(out)
|
395 |
+
out = self.conv2(out)
|
396 |
+
if self.norm2 is not None:
|
397 |
+
out = self.norm2(out)
|
398 |
+
|
399 |
+
return self.skip_add.add(out, x)
|
400 |
+
|
401 |
+
|
402 |
+
class FeatureFusionBlock(nn.Module):
|
403 |
+
"""Feature fusion block."""
|
404 |
+
|
405 |
+
def __init__(
|
406 |
+
self,
|
407 |
+
features,
|
408 |
+
activation,
|
409 |
+
deconv=False,
|
410 |
+
bn=False,
|
411 |
+
expand=False,
|
412 |
+
align_corners=True,
|
413 |
+
size=None,
|
414 |
+
has_residual=True,
|
415 |
+
groups=1,
|
416 |
+
):
|
417 |
+
"""Init.
|
418 |
+
|
419 |
+
Args:
|
420 |
+
features (int): number of features
|
421 |
+
"""
|
422 |
+
super(FeatureFusionBlock, self).__init__()
|
423 |
+
|
424 |
+
self.deconv = deconv
|
425 |
+
self.align_corners = align_corners
|
426 |
+
self.groups = groups
|
427 |
+
self.expand = expand
|
428 |
+
out_features = features
|
429 |
+
if self.expand == True:
|
430 |
+
out_features = features // 2
|
431 |
+
|
432 |
+
self.out_conv = nn.Conv2d(
|
433 |
+
features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
|
434 |
+
)
|
435 |
+
|
436 |
+
if has_residual:
|
437 |
+
self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
|
438 |
+
|
439 |
+
self.has_residual = has_residual
|
440 |
+
self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)
|
441 |
+
|
442 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
443 |
+
self.size = size
|
444 |
+
|
445 |
+
def forward(self, *xs, size=None):
|
446 |
+
"""Forward pass.
|
447 |
+
|
448 |
+
Returns:
|
449 |
+
tensor: output
|
450 |
+
"""
|
451 |
+
output = xs[0]
|
452 |
+
|
453 |
+
if self.has_residual:
|
454 |
+
res = self.resConfUnit1(xs[1])
|
455 |
+
output = self.skip_add.add(output, res)
|
456 |
+
|
457 |
+
output = self.resConfUnit2(output)
|
458 |
+
|
459 |
+
if (size is None) and (self.size is None):
|
460 |
+
modifier = {"scale_factor": 2}
|
461 |
+
elif size is None:
|
462 |
+
modifier = {"size": self.size}
|
463 |
+
else:
|
464 |
+
modifier = {"size": size}
|
465 |
+
|
466 |
+
output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
|
467 |
+
output = self.out_conv(output)
|
468 |
+
|
469 |
+
return output
|
470 |
+
|
471 |
+
|
472 |
+
def custom_interpolate(
|
473 |
+
x: torch.Tensor,
|
474 |
+
size: Tuple[int, int] = None,
|
475 |
+
scale_factor: float = None,
|
476 |
+
mode: str = "bilinear",
|
477 |
+
align_corners: bool = True,
|
478 |
+
) -> torch.Tensor:
|
479 |
+
"""
|
480 |
+
Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
|
481 |
+
"""
|
482 |
+
if size is None:
|
483 |
+
size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
|
484 |
+
|
485 |
+
INT_MAX = 1610612736
|
486 |
+
|
487 |
+
input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
|
488 |
+
|
489 |
+
if input_elements > INT_MAX:
|
490 |
+
chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
|
491 |
+
interpolated_chunks = [
|
492 |
+
nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks
|
493 |
+
]
|
494 |
+
x = torch.cat(interpolated_chunks, dim=0)
|
495 |
+
return x.contiguous()
|
496 |
+
else:
|
497 |
+
return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
|
vggt/heads/head_act.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
|
12 |
+
def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
|
13 |
+
"""
|
14 |
+
Activate pose parameters with specified activation functions.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
|
18 |
+
trans_act: Activation type for translation component
|
19 |
+
quat_act: Activation type for quaternion component
|
20 |
+
fl_act: Activation type for focal length component
|
21 |
+
|
22 |
+
Returns:
|
23 |
+
Activated pose parameters tensor
|
24 |
+
"""
|
25 |
+
T = pred_pose_enc[..., :3]
|
26 |
+
quat = pred_pose_enc[..., 3:7]
|
27 |
+
fl = pred_pose_enc[..., 7:] # or fov
|
28 |
+
|
29 |
+
T = base_pose_act(T, trans_act)
|
30 |
+
quat = base_pose_act(quat, quat_act)
|
31 |
+
fl = base_pose_act(fl, fl_act) # or fov
|
32 |
+
|
33 |
+
pred_pose_enc = torch.cat([T, quat, fl], dim=-1)
|
34 |
+
|
35 |
+
return pred_pose_enc
|
36 |
+
|
37 |
+
|
38 |
+
def base_pose_act(pose_enc, act_type="linear"):
|
39 |
+
"""
|
40 |
+
Apply basic activation function to pose parameters.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
pose_enc: Tensor containing encoded pose parameters
|
44 |
+
act_type: Activation type ("linear", "inv_log", "exp", "relu")
|
45 |
+
|
46 |
+
Returns:
|
47 |
+
Activated pose parameters
|
48 |
+
"""
|
49 |
+
if act_type == "linear":
|
50 |
+
return pose_enc
|
51 |
+
elif act_type == "inv_log":
|
52 |
+
return inverse_log_transform(pose_enc)
|
53 |
+
elif act_type == "exp":
|
54 |
+
return torch.exp(pose_enc)
|
55 |
+
elif act_type == "relu":
|
56 |
+
return F.relu(pose_enc)
|
57 |
+
else:
|
58 |
+
raise ValueError(f"Unknown act_type: {act_type}")
|
59 |
+
|
60 |
+
|
61 |
+
def activate_head(out, activation="norm_exp", conf_activation="expp1"):
|
62 |
+
"""
|
63 |
+
Process network output to extract 3D points and confidence values.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
out: Network output tensor (B, C, H, W)
|
67 |
+
activation: Activation type for 3D points
|
68 |
+
conf_activation: Activation type for confidence values
|
69 |
+
|
70 |
+
Returns:
|
71 |
+
Tuple of (3D points tensor, confidence tensor)
|
72 |
+
"""
|
73 |
+
# Move channels from last dim to the 4th dimension => (B, H, W, C)
|
74 |
+
fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected
|
75 |
+
|
76 |
+
# Split into xyz (first C-1 channels) and confidence (last channel)
|
77 |
+
xyz = fmap[:, :, :, :-1]
|
78 |
+
conf = fmap[:, :, :, -1]
|
79 |
+
|
80 |
+
if activation == "norm_exp":
|
81 |
+
d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
|
82 |
+
xyz_normed = xyz / d
|
83 |
+
pts3d = xyz_normed * torch.expm1(d)
|
84 |
+
elif activation == "norm":
|
85 |
+
pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
|
86 |
+
elif activation == "exp":
|
87 |
+
pts3d = torch.exp(xyz)
|
88 |
+
elif activation == "relu":
|
89 |
+
pts3d = F.relu(xyz)
|
90 |
+
elif activation == "inv_log":
|
91 |
+
pts3d = inverse_log_transform(xyz)
|
92 |
+
elif activation == "xy_inv_log":
|
93 |
+
xy, z = xyz.split([2, 1], dim=-1)
|
94 |
+
z = inverse_log_transform(z)
|
95 |
+
pts3d = torch.cat([xy * z, z], dim=-1)
|
96 |
+
elif activation == "sigmoid":
|
97 |
+
pts3d = torch.sigmoid(xyz)
|
98 |
+
elif activation == "linear":
|
99 |
+
pts3d = xyz
|
100 |
+
else:
|
101 |
+
raise ValueError(f"Unknown activation: {activation}")
|
102 |
+
|
103 |
+
if conf_activation == "expp1":
|
104 |
+
conf_out = 1 + conf.exp()
|
105 |
+
elif conf_activation == "expp0":
|
106 |
+
conf_out = conf.exp()
|
107 |
+
elif conf_activation == "sigmoid":
|
108 |
+
conf_out = torch.sigmoid(conf)
|
109 |
+
else:
|
110 |
+
raise ValueError(f"Unknown conf_activation: {conf_activation}")
|
111 |
+
|
112 |
+
return pts3d, conf_out
|
113 |
+
|
114 |
+
|
115 |
+
def inverse_log_transform(y):
|
116 |
+
"""
|
117 |
+
Apply inverse log transform: sign(y) * (exp(|y|) - 1)
|
118 |
+
|
119 |
+
Args:
|
120 |
+
y: Input tensor
|
121 |
+
|
122 |
+
Returns:
|
123 |
+
Transformed tensor
|
124 |
+
"""
|
125 |
+
return torch.sign(y) * (torch.expm1(torch.abs(y)))
|
vggt/heads/track_head.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch.nn as nn
|
8 |
+
from .dpt_head import DPTHead
|
9 |
+
from .track_modules.base_track_predictor import BaseTrackerPredictor
|
10 |
+
|
11 |
+
|
12 |
+
class TrackHead(nn.Module):
|
13 |
+
"""
|
14 |
+
Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking.
|
15 |
+
The tracking is performed iteratively, refining predictions over multiple iterations.
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
dim_in,
|
21 |
+
patch_size=14,
|
22 |
+
features=128,
|
23 |
+
iters=4,
|
24 |
+
predict_conf=True,
|
25 |
+
stride=2,
|
26 |
+
corr_levels=7,
|
27 |
+
corr_radius=4,
|
28 |
+
hidden_size=384,
|
29 |
+
):
|
30 |
+
"""
|
31 |
+
Initialize the TrackHead module.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
dim_in (int): Input dimension of tokens from the backbone.
|
35 |
+
patch_size (int): Size of image patches used in the vision transformer.
|
36 |
+
features (int): Number of feature channels in the feature extractor output.
|
37 |
+
iters (int): Number of refinement iterations for tracking predictions.
|
38 |
+
predict_conf (bool): Whether to predict confidence scores for tracked points.
|
39 |
+
stride (int): Stride value for the tracker predictor.
|
40 |
+
corr_levels (int): Number of correlation pyramid levels
|
41 |
+
corr_radius (int): Radius for correlation computation, controlling the search area.
|
42 |
+
hidden_size (int): Size of hidden layers in the tracker network.
|
43 |
+
"""
|
44 |
+
super().__init__()
|
45 |
+
|
46 |
+
self.patch_size = patch_size
|
47 |
+
|
48 |
+
# Feature extractor based on DPT architecture
|
49 |
+
# Processes tokens into feature maps for tracking
|
50 |
+
self.feature_extractor = DPTHead(
|
51 |
+
dim_in=dim_in,
|
52 |
+
patch_size=patch_size,
|
53 |
+
features=features,
|
54 |
+
feature_only=True, # Only output features, no activation
|
55 |
+
down_ratio=2, # Reduces spatial dimensions by factor of 2
|
56 |
+
pos_embed=False,
|
57 |
+
)
|
58 |
+
|
59 |
+
# Tracker module that predicts point trajectories
|
60 |
+
# Takes feature maps and predicts coordinates and visibility
|
61 |
+
self.tracker = BaseTrackerPredictor(
|
62 |
+
latent_dim=features, # Match the output_dim of feature extractor
|
63 |
+
predict_conf=predict_conf,
|
64 |
+
stride=stride,
|
65 |
+
corr_levels=corr_levels,
|
66 |
+
corr_radius=corr_radius,
|
67 |
+
hidden_size=hidden_size,
|
68 |
+
)
|
69 |
+
|
70 |
+
self.iters = iters
|
71 |
+
|
72 |
+
def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None):
|
73 |
+
"""
|
74 |
+
Forward pass of the TrackHead.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
aggregated_tokens_list (list): List of aggregated tokens from the backbone.
|
78 |
+
images (torch.Tensor): Input images of shape (B, S, C, H, W) where:
|
79 |
+
B = batch size, S = sequence length.
|
80 |
+
patch_start_idx (int): Starting index for patch tokens.
|
81 |
+
query_points (torch.Tensor, optional): Initial query points to track.
|
82 |
+
If None, points are initialized by the tracker.
|
83 |
+
iters (int, optional): Number of refinement iterations. If None, uses self.iters.
|
84 |
+
|
85 |
+
Returns:
|
86 |
+
tuple:
|
87 |
+
- coord_preds (torch.Tensor): Predicted coordinates for tracked points.
|
88 |
+
- vis_scores (torch.Tensor): Visibility scores for tracked points.
|
89 |
+
- conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True).
|
90 |
+
"""
|
91 |
+
B, S, _, H, W = images.shape
|
92 |
+
|
93 |
+
# Extract features from tokens
|
94 |
+
# feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2
|
95 |
+
feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx)
|
96 |
+
|
97 |
+
# Use default iterations if not specified
|
98 |
+
if iters is None:
|
99 |
+
iters = self.iters
|
100 |
+
|
101 |
+
# Perform tracking using the extracted features
|
102 |
+
coord_preds, vis_scores, conf_scores = self.tracker(
|
103 |
+
query_points=query_points,
|
104 |
+
fmaps=feature_maps,
|
105 |
+
iters=iters,
|
106 |
+
)
|
107 |
+
|
108 |
+
return coord_preds, vis_scores, conf_scores
|
vggt/heads/track_modules/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
vggt/heads/track_modules/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (149 Bytes). View file
|
|
vggt/heads/track_modules/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (174 Bytes). View file
|
|
vggt/heads/track_modules/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (162 Bytes). View file
|
|
vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-310.pyc
ADDED
Binary file (4.26 kB). View file
|
|
vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-311.pyc
ADDED
Binary file (9.38 kB). View file
|
|
vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-312.pyc
ADDED
Binary file (8.76 kB). View file
|
|
vggt/heads/track_modules/__pycache__/blocks.cpython-310.pyc
ADDED
Binary file (6.57 kB). View file
|
|
vggt/heads/track_modules/__pycache__/blocks.cpython-311.pyc
ADDED
Binary file (12.9 kB). View file
|
|
vggt/heads/track_modules/__pycache__/blocks.cpython-312.pyc
ADDED
Binary file (11.6 kB). View file
|
|
vggt/heads/track_modules/__pycache__/modules.cpython-310.pyc
ADDED
Binary file (5.26 kB). View file
|
|
vggt/heads/track_modules/__pycache__/modules.cpython-311.pyc
ADDED
Binary file (10 kB). View file
|
|
vggt/heads/track_modules/__pycache__/modules.cpython-312.pyc
ADDED
Binary file (8.79 kB). View file
|
|
vggt/heads/track_modules/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (7.35 kB). View file
|
|
vggt/heads/track_modules/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (11.1 kB). View file
|
|
vggt/heads/track_modules/__pycache__/utils.cpython-312.pyc
ADDED
Binary file (10.4 kB). View file
|
|
vggt/heads/track_modules/base_track_predictor.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from einops import rearrange, repeat
|
10 |
+
|
11 |
+
|
12 |
+
from .blocks import EfficientUpdateFormer, CorrBlock
|
13 |
+
from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed
|
14 |
+
from .modules import Mlp
|
15 |
+
|
16 |
+
|
17 |
+
class BaseTrackerPredictor(nn.Module):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
stride=1,
|
21 |
+
corr_levels=5,
|
22 |
+
corr_radius=4,
|
23 |
+
latent_dim=128,
|
24 |
+
hidden_size=384,
|
25 |
+
use_spaceatt=True,
|
26 |
+
depth=6,
|
27 |
+
max_scale=518,
|
28 |
+
predict_conf=True,
|
29 |
+
):
|
30 |
+
super(BaseTrackerPredictor, self).__init__()
|
31 |
+
"""
|
32 |
+
The base template to create a track predictor
|
33 |
+
|
34 |
+
Modified from https://github.com/facebookresearch/co-tracker/
|
35 |
+
and https://github.com/facebookresearch/vggsfm
|
36 |
+
"""
|
37 |
+
|
38 |
+
self.stride = stride
|
39 |
+
self.latent_dim = latent_dim
|
40 |
+
self.corr_levels = corr_levels
|
41 |
+
self.corr_radius = corr_radius
|
42 |
+
self.hidden_size = hidden_size
|
43 |
+
self.max_scale = max_scale
|
44 |
+
self.predict_conf = predict_conf
|
45 |
+
|
46 |
+
self.flows_emb_dim = latent_dim // 2
|
47 |
+
|
48 |
+
self.corr_mlp = Mlp(
|
49 |
+
in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2,
|
50 |
+
hidden_features=self.hidden_size,
|
51 |
+
out_features=self.latent_dim,
|
52 |
+
)
|
53 |
+
|
54 |
+
self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4
|
55 |
+
|
56 |
+
self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim))
|
57 |
+
|
58 |
+
space_depth = depth if use_spaceatt else 0
|
59 |
+
time_depth = depth
|
60 |
+
|
61 |
+
self.updateformer = EfficientUpdateFormer(
|
62 |
+
space_depth=space_depth,
|
63 |
+
time_depth=time_depth,
|
64 |
+
input_dim=self.transformer_dim,
|
65 |
+
hidden_size=self.hidden_size,
|
66 |
+
output_dim=self.latent_dim + 2,
|
67 |
+
mlp_ratio=4.0,
|
68 |
+
add_space_attn=use_spaceatt,
|
69 |
+
)
|
70 |
+
|
71 |
+
self.fmap_norm = nn.LayerNorm(self.latent_dim)
|
72 |
+
self.ffeat_norm = nn.GroupNorm(1, self.latent_dim)
|
73 |
+
|
74 |
+
# A linear layer to update track feats at each iteration
|
75 |
+
self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU())
|
76 |
+
|
77 |
+
self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
|
78 |
+
|
79 |
+
if predict_conf:
|
80 |
+
self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
|
81 |
+
|
82 |
+
def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True):
|
83 |
+
"""
|
84 |
+
query_points: B x N x 2, the number of batches, tracks, and xy
|
85 |
+
fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension.
|
86 |
+
note HH and WW is the size of feature maps instead of original images
|
87 |
+
"""
|
88 |
+
B, N, D = query_points.shape
|
89 |
+
B, S, C, HH, WW = fmaps.shape
|
90 |
+
|
91 |
+
assert D == 2, "Input points must be 2D coordinates"
|
92 |
+
|
93 |
+
# apply a layernorm to fmaps here
|
94 |
+
fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2))
|
95 |
+
fmaps = fmaps.permute(0, 1, 4, 2, 3)
|
96 |
+
|
97 |
+
# Scale the input query_points because we may downsample the images
|
98 |
+
# by down_ratio or self.stride
|
99 |
+
# e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map
|
100 |
+
# its query_points should be query_points/4
|
101 |
+
if down_ratio > 1:
|
102 |
+
query_points = query_points / float(down_ratio)
|
103 |
+
|
104 |
+
query_points = query_points / float(self.stride)
|
105 |
+
|
106 |
+
# Init with coords as the query points
|
107 |
+
# It means the search will start from the position of query points at the reference frames
|
108 |
+
coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)
|
109 |
+
|
110 |
+
# Sample/extract the features of the query points in the query frame
|
111 |
+
query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])
|
112 |
+
|
113 |
+
# init track feats by query feats
|
114 |
+
track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C
|
115 |
+
# back up the init coords
|
116 |
+
coords_backup = coords.clone()
|
117 |
+
|
118 |
+
fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius)
|
119 |
+
|
120 |
+
coord_preds = []
|
121 |
+
|
122 |
+
# Iterative Refinement
|
123 |
+
for _ in range(iters):
|
124 |
+
# Detach the gradients from the last iteration
|
125 |
+
# (in my experience, not very important for performance)
|
126 |
+
coords = coords.detach()
|
127 |
+
|
128 |
+
fcorrs = fcorr_fn.corr_sample(track_feats, coords)
|
129 |
+
|
130 |
+
corr_dim = fcorrs.shape[3]
|
131 |
+
fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim)
|
132 |
+
fcorrs_ = self.corr_mlp(fcorrs_)
|
133 |
+
|
134 |
+
# Movement of current coords relative to query points
|
135 |
+
flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
|
136 |
+
|
137 |
+
flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)
|
138 |
+
|
139 |
+
# (In my trials, it is also okay to just add the flows_emb instead of concat)
|
140 |
+
flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1)
|
141 |
+
|
142 |
+
track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
|
143 |
+
|
144 |
+
# Concatenate them as the input for the transformers
|
145 |
+
transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)
|
146 |
+
|
147 |
+
# 2D positional embed
|
148 |
+
# TODO: this can be much simplified
|
149 |
+
pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device)
|
150 |
+
sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0])
|
151 |
+
|
152 |
+
sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1)
|
153 |
+
|
154 |
+
x = transformer_input + sampled_pos_emb
|
155 |
+
|
156 |
+
# Add the query ref token to the track feats
|
157 |
+
query_ref_token = torch.cat(
|
158 |
+
[self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1
|
159 |
+
)
|
160 |
+
x = x + query_ref_token.to(x.device).to(x.dtype)
|
161 |
+
|
162 |
+
# B, N, S, C
|
163 |
+
x = rearrange(x, "(b n) s d -> b n s d", b=B)
|
164 |
+
|
165 |
+
# Compute the delta coordinates and delta track features
|
166 |
+
delta, _ = self.updateformer(x)
|
167 |
+
|
168 |
+
# BN, S, C
|
169 |
+
delta = rearrange(delta, " b n s d -> (b n) s d", b=B)
|
170 |
+
delta_coords_ = delta[:, :, :2]
|
171 |
+
delta_feats_ = delta[:, :, 2:]
|
172 |
+
|
173 |
+
track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
|
174 |
+
delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
|
175 |
+
|
176 |
+
# Update the track features
|
177 |
+
track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_
|
178 |
+
|
179 |
+
track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC
|
180 |
+
|
181 |
+
# B x S x N x 2
|
182 |
+
coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
|
183 |
+
|
184 |
+
# Force coord0 as query
|
185 |
+
# because we assume the query points should not be changed
|
186 |
+
coords[:, 0] = coords_backup[:, 0]
|
187 |
+
|
188 |
+
# The predicted tracks are in the original image scale
|
189 |
+
if down_ratio > 1:
|
190 |
+
coord_preds.append(coords * self.stride * down_ratio)
|
191 |
+
else:
|
192 |
+
coord_preds.append(coords * self.stride)
|
193 |
+
|
194 |
+
# B, S, N
|
195 |
+
vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
|
196 |
+
if apply_sigmoid:
|
197 |
+
vis_e = torch.sigmoid(vis_e)
|
198 |
+
|
199 |
+
if self.predict_conf:
|
200 |
+
conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
|
201 |
+
if apply_sigmoid:
|
202 |
+
conf_e = torch.sigmoid(conf_e)
|
203 |
+
else:
|
204 |
+
conf_e = None
|
205 |
+
|
206 |
+
if return_feat:
|
207 |
+
return coord_preds, vis_e, track_feats, query_track_feat, conf_e
|
208 |
+
else:
|
209 |
+
return coord_preds, vis_e, conf_e
|
vggt/heads/track_modules/blocks.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
|
8 |
+
# Modified from https://github.com/facebookresearch/co-tracker/
|
9 |
+
|
10 |
+
import math
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
|
15 |
+
from .utils import bilinear_sampler
|
16 |
+
from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock
|
17 |
+
|
18 |
+
|
19 |
+
class EfficientUpdateFormer(nn.Module):
|
20 |
+
"""
|
21 |
+
Transformer model that updates track estimates.
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
space_depth=6,
|
27 |
+
time_depth=6,
|
28 |
+
input_dim=320,
|
29 |
+
hidden_size=384,
|
30 |
+
num_heads=8,
|
31 |
+
output_dim=130,
|
32 |
+
mlp_ratio=4.0,
|
33 |
+
add_space_attn=True,
|
34 |
+
num_virtual_tracks=64,
|
35 |
+
):
|
36 |
+
super().__init__()
|
37 |
+
|
38 |
+
self.out_channels = 2
|
39 |
+
self.num_heads = num_heads
|
40 |
+
self.hidden_size = hidden_size
|
41 |
+
self.add_space_attn = add_space_attn
|
42 |
+
|
43 |
+
# Add input LayerNorm before linear projection
|
44 |
+
self.input_norm = nn.LayerNorm(input_dim)
|
45 |
+
self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
|
46 |
+
|
47 |
+
# Add output LayerNorm before final projection
|
48 |
+
self.output_norm = nn.LayerNorm(hidden_size)
|
49 |
+
self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
|
50 |
+
self.num_virtual_tracks = num_virtual_tracks
|
51 |
+
|
52 |
+
if self.add_space_attn:
|
53 |
+
self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))
|
54 |
+
else:
|
55 |
+
self.virual_tracks = None
|
56 |
+
|
57 |
+
self.time_blocks = nn.ModuleList(
|
58 |
+
[
|
59 |
+
AttnBlock(
|
60 |
+
hidden_size,
|
61 |
+
num_heads,
|
62 |
+
mlp_ratio=mlp_ratio,
|
63 |
+
attn_class=nn.MultiheadAttention,
|
64 |
+
)
|
65 |
+
for _ in range(time_depth)
|
66 |
+
]
|
67 |
+
)
|
68 |
+
|
69 |
+
if add_space_attn:
|
70 |
+
self.space_virtual_blocks = nn.ModuleList(
|
71 |
+
[
|
72 |
+
AttnBlock(
|
73 |
+
hidden_size,
|
74 |
+
num_heads,
|
75 |
+
mlp_ratio=mlp_ratio,
|
76 |
+
attn_class=nn.MultiheadAttention,
|
77 |
+
)
|
78 |
+
for _ in range(space_depth)
|
79 |
+
]
|
80 |
+
)
|
81 |
+
self.space_point2virtual_blocks = nn.ModuleList(
|
82 |
+
[CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
|
83 |
+
)
|
84 |
+
self.space_virtual2point_blocks = nn.ModuleList(
|
85 |
+
[CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
|
86 |
+
)
|
87 |
+
assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
|
88 |
+
self.initialize_weights()
|
89 |
+
|
90 |
+
def initialize_weights(self):
|
91 |
+
def _basic_init(module):
|
92 |
+
if isinstance(module, nn.Linear):
|
93 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
94 |
+
if module.bias is not None:
|
95 |
+
nn.init.constant_(module.bias, 0)
|
96 |
+
torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
|
97 |
+
|
98 |
+
self.apply(_basic_init)
|
99 |
+
|
100 |
+
def forward(self, input_tensor, mask=None):
|
101 |
+
# Apply input LayerNorm
|
102 |
+
input_tensor = self.input_norm(input_tensor)
|
103 |
+
tokens = self.input_transform(input_tensor)
|
104 |
+
|
105 |
+
init_tokens = tokens
|
106 |
+
|
107 |
+
B, _, T, _ = tokens.shape
|
108 |
+
|
109 |
+
if self.add_space_attn:
|
110 |
+
virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
|
111 |
+
tokens = torch.cat([tokens, virtual_tokens], dim=1)
|
112 |
+
|
113 |
+
_, N, _, _ = tokens.shape
|
114 |
+
|
115 |
+
j = 0
|
116 |
+
for i in range(len(self.time_blocks)):
|
117 |
+
time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
|
118 |
+
|
119 |
+
time_tokens = self.time_blocks[i](time_tokens)
|
120 |
+
|
121 |
+
tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
|
122 |
+
if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0):
|
123 |
+
space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C
|
124 |
+
point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
|
125 |
+
virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
|
126 |
+
|
127 |
+
virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask)
|
128 |
+
virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
|
129 |
+
point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask)
|
130 |
+
|
131 |
+
space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
|
132 |
+
tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C
|
133 |
+
j += 1
|
134 |
+
|
135 |
+
if self.add_space_attn:
|
136 |
+
tokens = tokens[:, : N - self.num_virtual_tracks]
|
137 |
+
|
138 |
+
tokens = tokens + init_tokens
|
139 |
+
|
140 |
+
# Apply output LayerNorm before final projection
|
141 |
+
tokens = self.output_norm(tokens)
|
142 |
+
flow = self.flow_head(tokens)
|
143 |
+
|
144 |
+
return flow, None
|
145 |
+
|
146 |
+
|
147 |
+
class CorrBlock:
|
148 |
+
def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"):
|
149 |
+
"""
|
150 |
+
Build a pyramid of feature maps from the input.
|
151 |
+
|
152 |
+
fmaps: Tensor (B, S, C, H, W)
|
153 |
+
num_levels: number of pyramid levels (each downsampled by factor 2)
|
154 |
+
radius: search radius for sampling correlation
|
155 |
+
multiple_track_feats: if True, split the target features per pyramid level
|
156 |
+
padding_mode: passed to grid_sample / bilinear_sampler
|
157 |
+
"""
|
158 |
+
B, S, C, H, W = fmaps.shape
|
159 |
+
self.S, self.C, self.H, self.W = S, C, H, W
|
160 |
+
self.num_levels = num_levels
|
161 |
+
self.radius = radius
|
162 |
+
self.padding_mode = padding_mode
|
163 |
+
self.multiple_track_feats = multiple_track_feats
|
164 |
+
|
165 |
+
# Build pyramid: each level is half the spatial resolution of the previous
|
166 |
+
self.fmaps_pyramid = [fmaps] # level 0 is full resolution
|
167 |
+
current_fmaps = fmaps
|
168 |
+
for i in range(num_levels - 1):
|
169 |
+
B, S, C, H, W = current_fmaps.shape
|
170 |
+
# Merge batch & sequence dimensions
|
171 |
+
current_fmaps = current_fmaps.reshape(B * S, C, H, W)
|
172 |
+
# Avg pool down by factor 2
|
173 |
+
current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2)
|
174 |
+
_, _, H_new, W_new = current_fmaps.shape
|
175 |
+
current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new)
|
176 |
+
self.fmaps_pyramid.append(current_fmaps)
|
177 |
+
|
178 |
+
# Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling.
|
179 |
+
# This grid is added to the (scaled) coordinate centroids.
|
180 |
+
r = self.radius
|
181 |
+
dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
|
182 |
+
dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
|
183 |
+
# delta: for every (dy,dx) displacement (i.e. Δx, Δy)
|
184 |
+
self.delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1) # shape: (2r+1, 2r+1, 2)
|
185 |
+
|
186 |
+
def corr_sample(self, targets, coords):
|
187 |
+
"""
|
188 |
+
Instead of storing the entire correlation pyramid, we compute each level's correlation
|
189 |
+
volume, sample it immediately, then discard it. This saves GPU memory.
|
190 |
+
|
191 |
+
Args:
|
192 |
+
targets: Tensor (B, S, N, C) — features for the current targets.
|
193 |
+
coords: Tensor (B, S, N, 2) — coordinates at full resolution.
|
194 |
+
|
195 |
+
Returns:
|
196 |
+
Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations)
|
197 |
+
"""
|
198 |
+
B, S, N, C = targets.shape
|
199 |
+
|
200 |
+
# If you have multiple track features, split them per level.
|
201 |
+
if self.multiple_track_feats:
|
202 |
+
targets_split = torch.split(targets, C // self.num_levels, dim=-1)
|
203 |
+
|
204 |
+
out_pyramid = []
|
205 |
+
for i, fmaps in enumerate(self.fmaps_pyramid):
|
206 |
+
# Get current spatial resolution H, W for this pyramid level.
|
207 |
+
B, S, C, H, W = fmaps.shape
|
208 |
+
# Reshape feature maps for correlation computation:
|
209 |
+
# fmap2s: (B, S, C, H*W)
|
210 |
+
fmap2s = fmaps.view(B, S, C, H * W)
|
211 |
+
# Choose appropriate target features.
|
212 |
+
fmap1 = targets_split[i] if self.multiple_track_feats else targets # shape: (B, S, N, C)
|
213 |
+
|
214 |
+
# Compute correlation directly
|
215 |
+
corrs = compute_corr_level(fmap1, fmap2s, C)
|
216 |
+
corrs = corrs.view(B, S, N, H, W)
|
217 |
+
|
218 |
+
# Prepare sampling grid:
|
219 |
+
# Scale down the coordinates for the current level.
|
220 |
+
centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i)
|
221 |
+
# Make sure our precomputed delta grid is on the same device/dtype.
|
222 |
+
delta_lvl = self.delta.to(coords.device).to(coords.dtype)
|
223 |
+
# Now the grid for grid_sample is:
|
224 |
+
# coords_lvl = centroid_lvl + delta_lvl (broadcasted over grid)
|
225 |
+
coords_lvl = centroid_lvl + delta_lvl.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2)
|
226 |
+
|
227 |
+
# Sample from the correlation volume using bilinear interpolation.
|
228 |
+
# We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target.
|
229 |
+
corrs_sampled = bilinear_sampler(
|
230 |
+
corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode
|
231 |
+
)
|
232 |
+
# The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims.
|
233 |
+
corrs_sampled = corrs_sampled.view(B, S, N, -1) # Now shape: (B, S, N, (2r+1)^2)
|
234 |
+
out_pyramid.append(corrs_sampled)
|
235 |
+
|
236 |
+
# Concatenate all levels along the last dimension.
|
237 |
+
out = torch.cat(out_pyramid, dim=-1).contiguous()
|
238 |
+
return out
|
239 |
+
|
240 |
+
|
241 |
+
def compute_corr_level(fmap1, fmap2s, C):
|
242 |
+
# fmap1: (B, S, N, C)
|
243 |
+
# fmap2s: (B, S, C, H*W)
|
244 |
+
corrs = torch.matmul(fmap1, fmap2s) # (B, S, N, H*W)
|
245 |
+
corrs = corrs.view(fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1) # (B, S, N, H*W)
|
246 |
+
return corrs / math.sqrt(C)
|
vggt/heads/track_modules/modules.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from functools import partial
|
12 |
+
from typing import Callable
|
13 |
+
import collections
|
14 |
+
from torch import Tensor
|
15 |
+
from itertools import repeat
|
16 |
+
|
17 |
+
|
18 |
+
# From PyTorch internals
|
19 |
+
def _ntuple(n):
|
20 |
+
def parse(x):
|
21 |
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
22 |
+
return tuple(x)
|
23 |
+
return tuple(repeat(x, n))
|
24 |
+
|
25 |
+
return parse
|
26 |
+
|
27 |
+
|
28 |
+
def exists(val):
|
29 |
+
return val is not None
|
30 |
+
|
31 |
+
|
32 |
+
def default(val, d):
|
33 |
+
return val if exists(val) else d
|
34 |
+
|
35 |
+
|
36 |
+
to_2tuple = _ntuple(2)
|
37 |
+
|
38 |
+
|
39 |
+
class ResidualBlock(nn.Module):
|
40 |
+
"""
|
41 |
+
ResidualBlock: construct a block of two conv layers with residual connections
|
42 |
+
"""
|
43 |
+
|
44 |
+
def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
|
45 |
+
super(ResidualBlock, self).__init__()
|
46 |
+
|
47 |
+
self.conv1 = nn.Conv2d(
|
48 |
+
in_planes,
|
49 |
+
planes,
|
50 |
+
kernel_size=kernel_size,
|
51 |
+
padding=1,
|
52 |
+
stride=stride,
|
53 |
+
padding_mode="zeros",
|
54 |
+
)
|
55 |
+
self.conv2 = nn.Conv2d(
|
56 |
+
planes,
|
57 |
+
planes,
|
58 |
+
kernel_size=kernel_size,
|
59 |
+
padding=1,
|
60 |
+
padding_mode="zeros",
|
61 |
+
)
|
62 |
+
self.relu = nn.ReLU(inplace=True)
|
63 |
+
|
64 |
+
num_groups = planes // 8
|
65 |
+
|
66 |
+
if norm_fn == "group":
|
67 |
+
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
68 |
+
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
69 |
+
if not stride == 1:
|
70 |
+
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
71 |
+
|
72 |
+
elif norm_fn == "batch":
|
73 |
+
self.norm1 = nn.BatchNorm2d(planes)
|
74 |
+
self.norm2 = nn.BatchNorm2d(planes)
|
75 |
+
if not stride == 1:
|
76 |
+
self.norm3 = nn.BatchNorm2d(planes)
|
77 |
+
|
78 |
+
elif norm_fn == "instance":
|
79 |
+
self.norm1 = nn.InstanceNorm2d(planes)
|
80 |
+
self.norm2 = nn.InstanceNorm2d(planes)
|
81 |
+
if not stride == 1:
|
82 |
+
self.norm3 = nn.InstanceNorm2d(planes)
|
83 |
+
|
84 |
+
elif norm_fn == "none":
|
85 |
+
self.norm1 = nn.Sequential()
|
86 |
+
self.norm2 = nn.Sequential()
|
87 |
+
if not stride == 1:
|
88 |
+
self.norm3 = nn.Sequential()
|
89 |
+
else:
|
90 |
+
raise NotImplementedError
|
91 |
+
|
92 |
+
if stride == 1:
|
93 |
+
self.downsample = None
|
94 |
+
else:
|
95 |
+
self.downsample = nn.Sequential(
|
96 |
+
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
|
97 |
+
self.norm3,
|
98 |
+
)
|
99 |
+
|
100 |
+
def forward(self, x):
|
101 |
+
y = x
|
102 |
+
y = self.relu(self.norm1(self.conv1(y)))
|
103 |
+
y = self.relu(self.norm2(self.conv2(y)))
|
104 |
+
|
105 |
+
if self.downsample is not None:
|
106 |
+
x = self.downsample(x)
|
107 |
+
|
108 |
+
return self.relu(x + y)
|
109 |
+
|
110 |
+
|
111 |
+
class Mlp(nn.Module):
|
112 |
+
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
113 |
+
|
114 |
+
def __init__(
|
115 |
+
self,
|
116 |
+
in_features,
|
117 |
+
hidden_features=None,
|
118 |
+
out_features=None,
|
119 |
+
act_layer=nn.GELU,
|
120 |
+
norm_layer=None,
|
121 |
+
bias=True,
|
122 |
+
drop=0.0,
|
123 |
+
use_conv=False,
|
124 |
+
):
|
125 |
+
super().__init__()
|
126 |
+
out_features = out_features or in_features
|
127 |
+
hidden_features = hidden_features or in_features
|
128 |
+
bias = to_2tuple(bias)
|
129 |
+
drop_probs = to_2tuple(drop)
|
130 |
+
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
131 |
+
|
132 |
+
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
|
133 |
+
self.act = act_layer()
|
134 |
+
self.drop1 = nn.Dropout(drop_probs[0])
|
135 |
+
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
|
136 |
+
self.drop2 = nn.Dropout(drop_probs[1])
|
137 |
+
|
138 |
+
def forward(self, x):
|
139 |
+
x = self.fc1(x)
|
140 |
+
x = self.act(x)
|
141 |
+
x = self.drop1(x)
|
142 |
+
x = self.fc2(x)
|
143 |
+
x = self.drop2(x)
|
144 |
+
return x
|
145 |
+
|
146 |
+
|
147 |
+
class AttnBlock(nn.Module):
|
148 |
+
def __init__(
|
149 |
+
self,
|
150 |
+
hidden_size,
|
151 |
+
num_heads,
|
152 |
+
attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
|
153 |
+
mlp_ratio=4.0,
|
154 |
+
**block_kwargs
|
155 |
+
):
|
156 |
+
"""
|
157 |
+
Self attention block
|
158 |
+
"""
|
159 |
+
super().__init__()
|
160 |
+
|
161 |
+
self.norm1 = nn.LayerNorm(hidden_size)
|
162 |
+
self.norm2 = nn.LayerNorm(hidden_size)
|
163 |
+
|
164 |
+
self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs)
|
165 |
+
|
166 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
167 |
+
|
168 |
+
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
|
169 |
+
|
170 |
+
def forward(self, x, mask=None):
|
171 |
+
# Prepare the mask for PyTorch's attention (it expects a different format)
|
172 |
+
# attn_mask = mask if mask is not None else None
|
173 |
+
# Normalize before attention
|
174 |
+
x = self.norm1(x)
|
175 |
+
|
176 |
+
# PyTorch's MultiheadAttention returns attn_output, attn_output_weights
|
177 |
+
# attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask)
|
178 |
+
|
179 |
+
attn_output, _ = self.attn(x, x, x)
|
180 |
+
|
181 |
+
# Add & Norm
|
182 |
+
x = x + attn_output
|
183 |
+
x = x + self.mlp(self.norm2(x))
|
184 |
+
return x
|
185 |
+
|
186 |
+
|
187 |
+
class CrossAttnBlock(nn.Module):
|
188 |
+
def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
|
189 |
+
"""
|
190 |
+
Cross attention block
|
191 |
+
"""
|
192 |
+
super().__init__()
|
193 |
+
|
194 |
+
self.norm1 = nn.LayerNorm(hidden_size)
|
195 |
+
self.norm_context = nn.LayerNorm(hidden_size)
|
196 |
+
self.norm2 = nn.LayerNorm(hidden_size)
|
197 |
+
|
198 |
+
self.cross_attn = nn.MultiheadAttention(
|
199 |
+
embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
|
200 |
+
)
|
201 |
+
|
202 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
203 |
+
|
204 |
+
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
|
205 |
+
|
206 |
+
def forward(self, x, context, mask=None):
|
207 |
+
# Normalize inputs
|
208 |
+
x = self.norm1(x)
|
209 |
+
context = self.norm_context(context)
|
210 |
+
|
211 |
+
# Apply cross attention
|
212 |
+
# Note: nn.MultiheadAttention returns attn_output, attn_output_weights
|
213 |
+
attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask)
|
214 |
+
|
215 |
+
# Add & Norm
|
216 |
+
x = x + attn_output
|
217 |
+
x = x + self.mlp(self.norm2(x))
|
218 |
+
return x
|
vggt/heads/track_modules/utils.py
ADDED
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# Modified from https://github.com/facebookresearch/vggsfm
|
8 |
+
# and https://github.com/facebookresearch/co-tracker/tree/main
|
9 |
+
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
|
15 |
+
from typing import Optional, Tuple, Union
|
16 |
+
|
17 |
+
|
18 |
+
def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor:
|
19 |
+
"""
|
20 |
+
This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
|
21 |
+
It is a wrapper of get_2d_sincos_pos_embed_from_grid.
|
22 |
+
Args:
|
23 |
+
- embed_dim: The embedding dimension.
|
24 |
+
- grid_size: The grid size.
|
25 |
+
Returns:
|
26 |
+
- pos_embed: The generated 2D positional embedding.
|
27 |
+
"""
|
28 |
+
if isinstance(grid_size, tuple):
|
29 |
+
grid_size_h, grid_size_w = grid_size
|
30 |
+
else:
|
31 |
+
grid_size_h = grid_size_w = grid_size
|
32 |
+
grid_h = torch.arange(grid_size_h, dtype=torch.float)
|
33 |
+
grid_w = torch.arange(grid_size_w, dtype=torch.float)
|
34 |
+
grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
|
35 |
+
grid = torch.stack(grid, dim=0)
|
36 |
+
grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
|
37 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
38 |
+
if return_grid:
|
39 |
+
return (
|
40 |
+
pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2),
|
41 |
+
grid,
|
42 |
+
)
|
43 |
+
return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
|
44 |
+
|
45 |
+
|
46 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor:
|
47 |
+
"""
|
48 |
+
This function generates a 2D positional embedding from a given grid using sine and cosine functions.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
- embed_dim: The embedding dimension.
|
52 |
+
- grid: The grid to generate the embedding from.
|
53 |
+
|
54 |
+
Returns:
|
55 |
+
- emb: The generated 2D positional embedding.
|
56 |
+
"""
|
57 |
+
assert embed_dim % 2 == 0
|
58 |
+
|
59 |
+
# use half of dimensions to encode grid_h
|
60 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
61 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
62 |
+
|
63 |
+
emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
|
64 |
+
return emb
|
65 |
+
|
66 |
+
|
67 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:
|
68 |
+
"""
|
69 |
+
This function generates a 1D positional embedding from a given grid using sine and cosine functions.
|
70 |
+
|
71 |
+
Args:
|
72 |
+
- embed_dim: The embedding dimension.
|
73 |
+
- pos: The position to generate the embedding from.
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
- emb: The generated 1D positional embedding.
|
77 |
+
"""
|
78 |
+
assert embed_dim % 2 == 0
|
79 |
+
omega = torch.arange(embed_dim // 2, dtype=torch.double)
|
80 |
+
omega /= embed_dim / 2.0
|
81 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
82 |
+
|
83 |
+
pos = pos.reshape(-1) # (M,)
|
84 |
+
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
85 |
+
|
86 |
+
emb_sin = torch.sin(out) # (M, D/2)
|
87 |
+
emb_cos = torch.cos(out) # (M, D/2)
|
88 |
+
|
89 |
+
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
|
90 |
+
return emb[None].float()
|
91 |
+
|
92 |
+
|
93 |
+
def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
|
94 |
+
"""
|
95 |
+
This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
- xy: The coordinates to generate the embedding from.
|
99 |
+
- C: The size of the embedding.
|
100 |
+
- cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
|
101 |
+
|
102 |
+
Returns:
|
103 |
+
- pe: The generated 2D positional embedding.
|
104 |
+
"""
|
105 |
+
B, N, D = xy.shape
|
106 |
+
assert D == 2
|
107 |
+
|
108 |
+
x = xy[:, :, 0:1]
|
109 |
+
y = xy[:, :, 1:2]
|
110 |
+
div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2))
|
111 |
+
|
112 |
+
pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
113 |
+
pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
114 |
+
|
115 |
+
pe_x[:, :, 0::2] = torch.sin(x * div_term)
|
116 |
+
pe_x[:, :, 1::2] = torch.cos(x * div_term)
|
117 |
+
|
118 |
+
pe_y[:, :, 0::2] = torch.sin(y * div_term)
|
119 |
+
pe_y[:, :, 1::2] = torch.cos(y * div_term)
|
120 |
+
|
121 |
+
pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3)
|
122 |
+
if cat_coords:
|
123 |
+
pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3)
|
124 |
+
return pe
|
125 |
+
|
126 |
+
|
127 |
+
def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
|
128 |
+
r"""Sample a tensor using bilinear interpolation
|
129 |
+
|
130 |
+
`bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
|
131 |
+
coordinates :attr:`coords` using bilinear interpolation. It is the same
|
132 |
+
as `torch.nn.functional.grid_sample()` but with a different coordinate
|
133 |
+
convention.
|
134 |
+
|
135 |
+
The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
|
136 |
+
:math:`B` is the batch size, :math:`C` is the number of channels,
|
137 |
+
:math:`H` is the height of the image, and :math:`W` is the width of the
|
138 |
+
image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
|
139 |
+
interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
|
140 |
+
|
141 |
+
Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
|
142 |
+
in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
|
143 |
+
that in this case the order of the components is slightly different
|
144 |
+
from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
|
145 |
+
|
146 |
+
If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
|
147 |
+
in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
|
148 |
+
left-most image pixel :math:`W-1` to the center of the right-most
|
149 |
+
pixel.
|
150 |
+
|
151 |
+
If `align_corners` is `False`, the coordinate :math:`x` is assumed to
|
152 |
+
be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
|
153 |
+
the left-most pixel :math:`W` to the right edge of the right-most
|
154 |
+
pixel.
|
155 |
+
|
156 |
+
Similar conventions apply to the :math:`y` for the range
|
157 |
+
:math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
|
158 |
+
:math:`[0,T-1]` and :math:`[0,T]`.
|
159 |
+
|
160 |
+
Args:
|
161 |
+
input (Tensor): batch of input images.
|
162 |
+
coords (Tensor): batch of coordinates.
|
163 |
+
align_corners (bool, optional): Coordinate convention. Defaults to `True`.
|
164 |
+
padding_mode (str, optional): Padding mode. Defaults to `"border"`.
|
165 |
+
|
166 |
+
Returns:
|
167 |
+
Tensor: sampled points.
|
168 |
+
"""
|
169 |
+
coords = coords.detach().clone()
|
170 |
+
############################################################
|
171 |
+
# IMPORTANT:
|
172 |
+
coords = coords.to(input.device).to(input.dtype)
|
173 |
+
############################################################
|
174 |
+
|
175 |
+
sizes = input.shape[2:]
|
176 |
+
|
177 |
+
assert len(sizes) in [2, 3]
|
178 |
+
|
179 |
+
if len(sizes) == 3:
|
180 |
+
# t x y -> x y t to match dimensions T H W in grid_sample
|
181 |
+
coords = coords[..., [1, 2, 0]]
|
182 |
+
|
183 |
+
if align_corners:
|
184 |
+
scale = torch.tensor(
|
185 |
+
[2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device, dtype=coords.dtype
|
186 |
+
)
|
187 |
+
else:
|
188 |
+
scale = torch.tensor([2 / size for size in reversed(sizes)], device=coords.device, dtype=coords.dtype)
|
189 |
+
|
190 |
+
coords.mul_(scale) # coords = coords * scale
|
191 |
+
coords.sub_(1) # coords = coords - 1
|
192 |
+
|
193 |
+
return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
|
194 |
+
|
195 |
+
|
196 |
+
def sample_features4d(input, coords):
|
197 |
+
r"""Sample spatial features
|
198 |
+
|
199 |
+
`sample_features4d(input, coords)` samples the spatial features
|
200 |
+
:attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
|
201 |
+
|
202 |
+
The field is sampled at coordinates :attr:`coords` using bilinear
|
203 |
+
interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
|
204 |
+
2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
|
205 |
+
same convention as :func:`bilinear_sampler` with `align_corners=True`.
|
206 |
+
|
207 |
+
The output tensor has one feature per point, and has shape :math:`(B,
|
208 |
+
R, C)`.
|
209 |
+
|
210 |
+
Args:
|
211 |
+
input (Tensor): spatial features.
|
212 |
+
coords (Tensor): points.
|
213 |
+
|
214 |
+
Returns:
|
215 |
+
Tensor: sampled features.
|
216 |
+
"""
|
217 |
+
|
218 |
+
B, _, _, _ = input.shape
|
219 |
+
|
220 |
+
# B R 2 -> B R 1 2
|
221 |
+
coords = coords.unsqueeze(2)
|
222 |
+
|
223 |
+
# B C R 1
|
224 |
+
feats = bilinear_sampler(input, coords)
|
225 |
+
|
226 |
+
return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C
|