diff --git a/app.py b/app.py
index cbffdf1ba490e3ae1fb244c10909cccfa7652993..c9107c6dc1a20cbfad7b23ddf69711287b8c14e3 100644
--- a/app.py
+++ b/app.py
@@ -1,7 +1,713 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import cv2
+import torch
+import numpy as np
import gradio as gr
+import sys
+import shutil
+from datetime import datetime
+import glob
+import gc
+import time
+
+from visual_util import predictions_to_glb
+from vggt.models.vggt import VGGT
+from vggt.utils.load_fn import load_and_preprocess_images
+from vggt.utils.pose_enc import pose_encoding_to_extri_intri
+from vggt.utils.geometry import unproject_depth_map_to_point_map
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+print("Initializing and loading VGGT model...")
+# model = VGGT.from_pretrained("facebook/VGGT-1B") # another way to load the model
+
+model_path = "https://huggingface.co/lch01/StreamVGGT/blob/main/checkpoints.pth"
+model = VGGT(use_causal_global=True, use_distil=True)
+ckpt = torch.load(torch.hub.load_state_dict_from_url(model_path), map_location=device)
+model.load_state_dict(ckpt, strict=True)
+model = model.to(device)
+model.eval()
+del ckpt
+
+
+# -------------------------------------------------------------------------
+# 1) Core model inference
+# -------------------------------------------------------------------------
+def run_model(target_dir, model) -> dict:
+ """
+ Run the VGGT model on images in the 'target_dir/images' folder and return predictions.
+ """
+ print(f"Processing images from {target_dir}")
+
+ # Device check
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ if not torch.cuda.is_available():
+ raise ValueError("CUDA is not available. Check your environment.")
+
+ # Move model to device
+ model = model.to(device)
+ model.eval()
+
+ # Load and preprocess images
+ image_names = glob.glob(os.path.join(target_dir, "images", "*"))
+ image_names = sorted(image_names)
+ print(f"Found {len(image_names)} images")
+ if len(image_names) == 0:
+ raise ValueError("No images found. Check your upload.")
+
+ images = load_and_preprocess_images(image_names).to(device)
+ print(f"Preprocessed images shape: {images.shape}")
+
+ # Run inference
+ print("Running inference...")
+ dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
+
+ with torch.no_grad():
+ with torch.cuda.amp.autocast(dtype=dtype):
+ predictions = model(images)
+
+ # Convert pose encoding to extrinsic and intrinsic matrices
+ print("Converting pose encoding to extrinsic and intrinsic matrices...")
+ extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])
+ predictions["extrinsic"] = extrinsic
+ predictions["intrinsic"] = intrinsic
+
+ # Convert tensors to numpy
+ for key in predictions.keys():
+ if isinstance(predictions[key], torch.Tensor):
+ predictions[key] = predictions[key].cpu().numpy().squeeze(0) # remove batch dimension
+ predictions['pose_enc_list'] = None # remove pose_enc_list
+
+ # Generate world points from depth map
+ print("Computing world points from depth map...")
+ depth_map = predictions["depth"] # (S, H, W, 1)
+ world_points = unproject_depth_map_to_point_map(depth_map, predictions["extrinsic"], predictions["intrinsic"])
+ predictions["world_points_from_depth"] = world_points
+
+ # Clean up
+ torch.cuda.empty_cache()
+ return predictions
+
+
+# -------------------------------------------------------------------------
+# 2) Handle uploaded video/images --> produce target_dir + images
+# -------------------------------------------------------------------------
+def handle_uploads(input_video, input_images):
+ """
+ Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
+ images or extracted frames from video into it. Return (target_dir, image_paths).
+ """
+ start_time = time.time()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ # Create a unique folder name
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+ target_dir = f"input_images_{timestamp}"
+ target_dir_images = os.path.join(target_dir, "images")
+
+ # Clean up if somehow that folder already exists
+ if os.path.exists(target_dir):
+ shutil.rmtree(target_dir)
+ os.makedirs(target_dir)
+ os.makedirs(target_dir_images)
+
+ image_paths = []
+
+ # --- Handle images ---
+ if input_images is not None:
+ for file_data in input_images:
+ if isinstance(file_data, dict) and "name" in file_data:
+ file_path = file_data["name"]
+ else:
+ file_path = file_data
+ dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
+ shutil.copy(file_path, dst_path)
+ image_paths.append(dst_path)
+
+ # --- Handle video ---
+ if input_video is not None:
+ if isinstance(input_video, dict) and "name" in input_video:
+ video_path = input_video["name"]
+ else:
+ video_path = input_video
+
+ vs = cv2.VideoCapture(video_path)
+ fps = vs.get(cv2.CAP_PROP_FPS)
+ frame_interval = int(fps * 1) # 1 frame/sec
+
+ count = 0
+ video_frame_num = 0
+ while True:
+ gotit, frame = vs.read()
+ if not gotit:
+ break
+ count += 1
+ if count % frame_interval == 0:
+ image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.png")
+ cv2.imwrite(image_path, frame)
+ image_paths.append(image_path)
+ video_frame_num += 1
+
+ # Sort final images for gallery
+ image_paths = sorted(image_paths)
+
+ end_time = time.time()
+ print(f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds")
+ return target_dir, image_paths
+
+
+# -------------------------------------------------------------------------
+# 3) Update gallery on upload
+# -------------------------------------------------------------------------
+def update_gallery_on_upload(input_video, input_images):
+ """
+ Whenever user uploads or changes files, immediately handle them
+ and show in the gallery. Return (target_dir, image_paths).
+ If nothing is uploaded, returns "None" and empty list.
+ """
+ if not input_video and not input_images:
+ return None, None, None, None
+ target_dir, image_paths = handle_uploads(input_video, input_images)
+ return None, target_dir, image_paths, "Upload complete. Click 'Reconstruct' to begin 3D processing."
+
+
+# -------------------------------------------------------------------------
+# 4) Reconstruction: uses the target_dir plus any viz parameters
+# -------------------------------------------------------------------------
+def gradio_demo(
+ target_dir,
+ conf_thres=3.0,
+ frame_filter="All",
+ mask_black_bg=False,
+ mask_white_bg=False,
+ show_cam=True,
+ mask_sky=False,
+ prediction_mode="Pointmap Regression",
+):
+ """
+ Perform reconstruction using the already-created target_dir/images.
+ """
+ if not os.path.isdir(target_dir) or target_dir == "None":
+ return None, "No valid target directory found. Please upload first.", None, None
+
+ start_time = time.time()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ # Prepare frame_filter dropdown
+ target_dir_images = os.path.join(target_dir, "images")
+ all_files = sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else []
+ all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
+ frame_filter_choices = ["All"] + all_files
+
+ print("Running run_model...")
+ with torch.no_grad():
+ predictions = run_model(target_dir, model)
+
+ # Save predictions
+ prediction_save_path = os.path.join(target_dir, "predictions.npz")
+ np.savez(prediction_save_path, **predictions)
+
+ # Handle None frame_filter
+ if frame_filter is None:
+ frame_filter = "All"
+
+ # Build a GLB file name
+ glbfile = os.path.join(
+ target_dir,
+ f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb",
+ )
+
+ # Convert predictions to GLB
+ glbscene = predictions_to_glb(
+ predictions,
+ conf_thres=conf_thres,
+ filter_by_frames=frame_filter,
+ mask_black_bg=mask_black_bg,
+ mask_white_bg=mask_white_bg,
+ show_cam=show_cam,
+ mask_sky=mask_sky,
+ target_dir=target_dir,
+ prediction_mode=prediction_mode,
+ )
+ glbscene.export(file_obj=glbfile)
+
+ # Cleanup
+ del predictions
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ end_time = time.time()
+ print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
+ log_msg = f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
+
+ return glbfile, log_msg, gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True)
+
+
+# -------------------------------------------------------------------------
+# 5) Helper functions for UI resets + re-visualization
+# -------------------------------------------------------------------------
+def clear_fields():
+ """
+ Clears the 3D viewer, the stored target_dir, and empties the gallery.
+ """
+ return None
+
+
+def update_log():
+ """
+ Display a quick log message while waiting.
+ """
+ return "Loading and Reconstructing..."
+
+
+def update_visualization(
+ target_dir, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, is_example
+):
+ """
+ Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
+ and return it for the 3D viewer. If is_example == "True", skip.
+ """
+
+ # If it's an example click, skip as requested
+ if is_example == "True":
+ return None, "No reconstruction available. Please click the Reconstruct button first."
+
+ if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
+ return None, "No reconstruction available. Please click the Reconstruct button first."
+
+ predictions_path = os.path.join(target_dir, "predictions.npz")
+ if not os.path.exists(predictions_path):
+ return None, f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first."
+
+ key_list = [
+ "pose_enc",
+ "depth",
+ "depth_conf",
+ "world_points",
+ "world_points_conf",
+ "images",
+ "extrinsic",
+ "intrinsic",
+ "world_points_from_depth",
+ ]
+
+ loaded = np.load(predictions_path)
+ predictions = {key: np.array(loaded[key]) for key in key_list}
+
+ glbfile = os.path.join(
+ target_dir,
+ f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb",
+ )
+
+ if not os.path.exists(glbfile):
+ glbscene = predictions_to_glb(
+ predictions,
+ conf_thres=conf_thres,
+ filter_by_frames=frame_filter,
+ mask_black_bg=mask_black_bg,
+ mask_white_bg=mask_white_bg,
+ show_cam=show_cam,
+ mask_sky=mask_sky,
+ target_dir=target_dir,
+ prediction_mode=prediction_mode,
+ )
+ glbscene.export(file_obj=glbfile)
+
+ return glbfile, "Updating Visualization"
+
+# -------------------------------------------------------------------------
+# Example images
+# -------------------------------------------------------------------------
+
+
+def get_examples_from_folder(images_folder):
+ """
+ Create an example using all JPG/JPEG files from the specified folder.
+ No caching, directly uses the images from the folder.
+ """
+ examples = []
+
+ if not os.path.exists(images_folder):
+ print(f"Warning: Images folder {images_folder} does not exist.")
+ return examples
+
+ image_files = []
+ for ext in ['*.jpg', '*.jpeg', '*.JPG', '*.JPEG', '*.png', '*.PNG']:
+ image_files.extend(glob.glob(os.path.join(images_folder, ext)))
+
+ image_files = sorted(image_files)
+
+ if not image_files:
+ print(f"Warning: No images found in {images_folder}.")
+ return examples
+
+ num_images = len(image_files)
+ print(f"Found {num_images} images in {images_folder}")
+
+ example = [
+ None,
+ str(num_images),
+ image_files,
+ 20.0,
+ False,
+ False,
+ True,
+ False,
+ "Depthmap and Camera Branch",
+ "True"
+ ]
+
+ examples.append(example)
+ return examples
+
+building_folder = "example_building/"
+
+# -------------------------------------------------------------------------
+# 6) Build Gradio UI
+# -------------------------------------------------------------------------
+theme = gr.themes.Ocean()
+theme.set(
+ checkbox_label_background_fill_selected="*button_primary_background_fill",
+ checkbox_label_text_color_selected="*button_primary_text_color",
+)
+
+with gr.Blocks(
+ theme=theme,
+ css="""
+ .custom-log * {
+ font-style: italic;
+ font-size: 22px !important;
+ background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
+ -webkit-background-clip: text;
+ background-clip: text;
+ font-weight: bold !important;
+ color: transparent !important;
+ text-align: center !important;
+ }
+
+ .example-log * {
+ font-style: italic;
+ font-size: 16px !important;
+ background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
+ -webkit-background-clip: text;
+ background-clip: text;
+ color: transparent !important;
+ }
+
+ #my_radio .wrap {
+ display: flex;
+ flex-wrap: nowrap;
+ justify-content: center;
+ align-items: center;
+ }
+
+ #my_radio .wrap label {
+ display: flex;
+ width: 50%;
+ justify-content: center;
+ align-items: center;
+ margin: 0;
+ padding: 10px 0;
+ box-sizing: border-box;
+ }
+ """,
+) as demo:
+ # Instead of gr.State, we use a hidden Textbox:
+ is_example = gr.Textbox(label="is_example", visible=False, value="None")
+ num_images = gr.Textbox(label="num_images", visible=False, value="None")
+
+ gr.HTML(
+ """
+
🏛️ VGGT: Visual Geometry Grounded Transformer
+
+ 🐙 GitHub Repository |
+ Project Page
+
+
+
+
Upload a video or a set of images to create a 3D reconstruction of a scene or object. VGGT takes these images and generates a 3D point cloud, along with estimated camera poses.
+
+
Getting Started:
+
+ - Upload Your Data: Use the "Upload Video" or "Upload Images" buttons on the left to provide your input. Videos will be automatically split into individual frames (one frame per second).
+ - Preview: Your uploaded images will appear in the gallery on the left.
+ - Reconstruct: Click the "Reconstruct" button to start the 3D reconstruction process.
+ - Visualize: The 3D reconstruction will appear in the viewer on the right. You can rotate, pan, and zoom to explore the model, and download the GLB file. Note the visualization of 3D points may be slow for a large number of input images.
+ -
+ Adjust Visualization (Optional):
+ After reconstruction, you can fine-tune the visualization using the options below
+
+ (click to expand):
+
+ - Confidence Threshold: Adjust the filtering of points based on confidence.
+ - Show Points from Frame: Select specific frames to display in the point cloud.
+ - Show Camera: Toggle the display of estimated camera positions.
+ - Filter Sky / Filter Black Background: Remove sky or black-background points.
+ - Select a Prediction Mode: Choose between "Depthmap and Camera Branch" or "Pointmap Branch."
+
+
+
+
+
Please note: VGGT typically reconstructs a scene in less than 1 second. However, visualizing 3D points may take tens of seconds due to third-party rendering, which are independent of VGGT's processing time.
+
+ """
+ )
+
+ target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
+
+ with gr.Row():
+ with gr.Column(scale=2):
+ input_video = gr.Video(label="Upload Video", interactive=True)
+ input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)
+
+ image_gallery = gr.Gallery(
+ label="Preview",
+ columns=4,
+ height="300px",
+ show_download_button=True,
+ object_fit="contain",
+ preview=True,
+ )
+
+ with gr.Column(scale=4):
+ with gr.Column():
+ gr.Markdown("**3D Reconstruction (Point Cloud and Camera Poses)**")
+ log_output = gr.Markdown(
+ "Please upload a video or images, then click Reconstruct.", elem_classes=["custom-log"]
+ )
+ reconstruction_output = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5)
+
+ with gr.Row():
+ submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
+ clear_btn = gr.ClearButton(
+ [input_video, input_images, reconstruction_output, log_output, target_dir_output, image_gallery],
+ scale=1,
+ )
+
+ with gr.Row():
+ prediction_mode = gr.Radio(
+ ["Depthmap and Camera Branch", "Pointmap Branch"],
+ label="Select a Prediction Mode",
+ value="Depthmap and Camera Branch",
+ scale=1,
+ elem_id="my_radio",
+ )
+
+ with gr.Row():
+ conf_thres = gr.Slider(minimum=0, maximum=100, value=50, step=0.1, label="Confidence Threshold (%)")
+ frame_filter = gr.Dropdown(choices=["All"], value="All", label="Show Points from Frame")
+ with gr.Column():
+ show_cam = gr.Checkbox(label="Show Camera", value=True)
+ mask_sky = gr.Checkbox(label="Filter Sky", value=False)
+ mask_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
+ mask_white_bg = gr.Checkbox(label="Filter White Background", value=False)
+
+ # ---------------------- Examples section ----------------------
+ examples = get_examples_from_folder(building_folder)
+
+ def example_pipeline(
+ input_video,
+ num_images_str,
+ input_images,
+ conf_thres,
+ mask_black_bg,
+ mask_white_bg,
+ show_cam,
+ mask_sky,
+ prediction_mode,
+ is_example_str,
+ ):
+ """
+ 1) Copy example images to new target_dir
+ 2) Reconstruct
+ 3) Return model3D + logs + new_dir + updated dropdown + gallery
+ We do NOT return is_example. It's just an input.
+ """
+ target_dir, image_paths = handle_uploads(input_video, input_images)
+ # Always use "All" for frame_filter in examples
+ frame_filter = "All"
+ glbfile, log_msg, dropdown = gradio_demo(
+ target_dir, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode
+ )
+ return glbfile, log_msg, target_dir, dropdown, image_paths
+
+ gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])
+
+ gr.Examples(
+ examples=examples,
+ inputs=[
+ input_video,
+ num_images,
+ input_images,
+ conf_thres,
+ mask_black_bg,
+ mask_white_bg,
+ show_cam,
+ mask_sky,
+ prediction_mode,
+ is_example,
+ ],
+ outputs=[reconstruction_output, log_output, target_dir_output, frame_filter, image_gallery],
+ fn=example_pipeline,
+ cache_examples=False,
+ examples_per_page=50,
+ )
+
+ # -------------------------------------------------------------------------
+ # "Reconstruct" button logic:
+ # - Clear fields
+ # - Update log
+ # - gradio_demo(...) with the existing target_dir
+ # - Then set is_example = "False"
+ # -------------------------------------------------------------------------
+ submit_btn.click(fn=clear_fields, inputs=[], outputs=[reconstruction_output]).then(
+ fn=update_log, inputs=[], outputs=[log_output]
+ ).then(
+ fn=gradio_demo,
+ inputs=[
+ target_dir_output,
+ conf_thres,
+ frame_filter,
+ mask_black_bg,
+ mask_white_bg,
+ show_cam,
+ mask_sky,
+ prediction_mode,
+ ],
+ outputs=[reconstruction_output, log_output, frame_filter],
+ ).then(
+ fn=lambda: "False", inputs=[], outputs=[is_example] # set is_example to "False"
+ )
+
+ # -------------------------------------------------------------------------
+ # Real-time Visualization Updates
+ # -------------------------------------------------------------------------
+ conf_thres.change(
+ update_visualization,
+ [
+ target_dir_output,
+ conf_thres,
+ frame_filter,
+ mask_black_bg,
+ mask_white_bg,
+ show_cam,
+ mask_sky,
+ prediction_mode,
+ is_example,
+ ],
+ [reconstruction_output, log_output],
+ )
+ frame_filter.change(
+ update_visualization,
+ [
+ target_dir_output,
+ conf_thres,
+ frame_filter,
+ mask_black_bg,
+ mask_white_bg,
+ show_cam,
+ mask_sky,
+ prediction_mode,
+ is_example,
+ ],
+ [reconstruction_output, log_output],
+ )
+ mask_black_bg.change(
+ update_visualization,
+ [
+ target_dir_output,
+ conf_thres,
+ frame_filter,
+ mask_black_bg,
+ mask_white_bg,
+ show_cam,
+ mask_sky,
+ prediction_mode,
+ is_example,
+ ],
+ [reconstruction_output, log_output],
+ )
+ mask_white_bg.change(
+ update_visualization,
+ [
+ target_dir_output,
+ conf_thres,
+ frame_filter,
+ mask_black_bg,
+ mask_white_bg,
+ show_cam,
+ mask_sky,
+ prediction_mode,
+ is_example,
+ ],
+ [reconstruction_output, log_output],
+ )
+ show_cam.change(
+ update_visualization,
+ [
+ target_dir_output,
+ conf_thres,
+ frame_filter,
+ mask_black_bg,
+ mask_white_bg,
+ show_cam,
+ mask_sky,
+ prediction_mode,
+ is_example,
+ ],
+ [reconstruction_output, log_output],
+ )
+ mask_sky.change(
+ update_visualization,
+ [
+ target_dir_output,
+ conf_thres,
+ frame_filter,
+ mask_black_bg,
+ mask_white_bg,
+ show_cam,
+ mask_sky,
+ prediction_mode,
+ is_example,
+ ],
+ [reconstruction_output, log_output],
+ )
+ prediction_mode.change(
+ update_visualization,
+ [
+ target_dir_output,
+ conf_thres,
+ frame_filter,
+ mask_black_bg,
+ mask_white_bg,
+ show_cam,
+ mask_sky,
+ prediction_mode,
+ is_example,
+ ],
+ [reconstruction_output, log_output],
+ )
-def greet(name):
- return "Hello " + name + "!!"
+ # -------------------------------------------------------------------------
+ # Auto-update gallery whenever user uploads or changes their files
+ # -------------------------------------------------------------------------
+ input_video.change(
+ fn=update_gallery_on_upload,
+ inputs=[input_video, input_images],
+ outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
+ )
+ input_images.change(
+ fn=update_gallery_on_upload,
+ inputs=[input_video, input_images],
+ outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
+ )
-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
-demo.launch()
\ No newline at end of file
+ demo.queue(max_size=20).launch(show_error=True, share=True)
\ No newline at end of file
diff --git a/example_building/0.jpg b/example_building/0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..eef566fc7401c3ee724fb47b67281c2dcbcf4f22
Binary files /dev/null and b/example_building/0.jpg differ
diff --git a/example_building/1.jpg b/example_building/1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3db5318d47f7b53401c75f63db5dec355255078f
Binary files /dev/null and b/example_building/1.jpg differ
diff --git a/example_building/2.jpg b/example_building/2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f77bba74dab1a8c0defd078cedb63ee71101e6af
Binary files /dev/null and b/example_building/2.jpg differ
diff --git a/example_building/3.jpg b/example_building/3.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..58fd4766c7f7107732db5a9b92556c5acf89c58e
Binary files /dev/null and b/example_building/3.jpg differ
diff --git a/example_building/4.jpg b/example_building/4.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..067342b81c790dba147b04f32ed37f4cfde83ee1
Binary files /dev/null and b/example_building/4.jpg differ
diff --git a/example_building/5.jpg b/example_building/5.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9a40835c7ac3c93dbc7170671ae44c93ba437c7f
Binary files /dev/null and b/example_building/5.jpg differ
diff --git a/example_building/6.jpg b/example_building/6.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e72613d86397f728693a3b1c38d865ee543e0e23
Binary files /dev/null and b/example_building/6.jpg differ
diff --git a/example_building/7.jpg b/example_building/7.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..56b55732608a12f858c855ac34a734c5bfb9c80e
Binary files /dev/null and b/example_building/7.jpg differ
diff --git a/example_building/8.jpg b/example_building/8.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..95dcab172ddcbb9d04e04481dc1bd33561af489a
Binary files /dev/null and b/example_building/8.jpg differ
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0565c7c98fb56760c0f8b59adb98f526b832bf62
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,19 @@
+torch==2.3.1
+torchvision==0.18.1
+numpy==1.26.1
+Pillow
+huggingface_hub
+einops
+safetensors
+gradio
+viser==0.2.23
+tqdm
+hydra-core
+omegaconf
+opencv-python
+scipy
+onnxruntime
+requests
+trimesh
+matplotlib
+gradio_client
\ No newline at end of file
diff --git a/vggt/heads/__pycache__/camera_head.cpython-310.pyc b/vggt/heads/__pycache__/camera_head.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8cbfa972f550663138d0035dbbb4ae4af4576ca6
Binary files /dev/null and b/vggt/heads/__pycache__/camera_head.cpython-310.pyc differ
diff --git a/vggt/heads/__pycache__/camera_head.cpython-311.pyc b/vggt/heads/__pycache__/camera_head.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cc25b236c270c26df0c6703591904919638f018d
Binary files /dev/null and b/vggt/heads/__pycache__/camera_head.cpython-311.pyc differ
diff --git a/vggt/heads/__pycache__/camera_head.cpython-312.pyc b/vggt/heads/__pycache__/camera_head.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ec75f53e472e06c436761d154039b2a530bbd877
Binary files /dev/null and b/vggt/heads/__pycache__/camera_head.cpython-312.pyc differ
diff --git a/vggt/heads/__pycache__/dpt_head.cpython-310.pyc b/vggt/heads/__pycache__/dpt_head.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..43b807dd48cb6b3aa057833377975d59b6a52a90
Binary files /dev/null and b/vggt/heads/__pycache__/dpt_head.cpython-310.pyc differ
diff --git a/vggt/heads/__pycache__/dpt_head.cpython-311.pyc b/vggt/heads/__pycache__/dpt_head.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..87eff0fa249a0f2346d5dbd39454dd5012acf018
Binary files /dev/null and b/vggt/heads/__pycache__/dpt_head.cpython-311.pyc differ
diff --git a/vggt/heads/__pycache__/dpt_head.cpython-312.pyc b/vggt/heads/__pycache__/dpt_head.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ca35ebc9f6a7fe354fdc01b02b21b7140a94b941
Binary files /dev/null and b/vggt/heads/__pycache__/dpt_head.cpython-312.pyc differ
diff --git a/vggt/heads/__pycache__/head_act.cpython-310.pyc b/vggt/heads/__pycache__/head_act.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..93b79bdef56a06e2835e2283513c3aafa878aa99
Binary files /dev/null and b/vggt/heads/__pycache__/head_act.cpython-310.pyc differ
diff --git a/vggt/heads/__pycache__/head_act.cpython-311.pyc b/vggt/heads/__pycache__/head_act.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f3dee4ed1c0344e9155274d7eb5bbd1a15a7bf1b
Binary files /dev/null and b/vggt/heads/__pycache__/head_act.cpython-311.pyc differ
diff --git a/vggt/heads/__pycache__/head_act.cpython-312.pyc b/vggt/heads/__pycache__/head_act.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4cfecec8f876bd14796c689317e5777bd3bc06b0
Binary files /dev/null and b/vggt/heads/__pycache__/head_act.cpython-312.pyc differ
diff --git a/vggt/heads/__pycache__/track_head.cpython-310.pyc b/vggt/heads/__pycache__/track_head.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6af9058c87e1cbaaf0bba0196d64ede43e536966
Binary files /dev/null and b/vggt/heads/__pycache__/track_head.cpython-310.pyc differ
diff --git a/vggt/heads/__pycache__/track_head.cpython-311.pyc b/vggt/heads/__pycache__/track_head.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fffa3acbcda62f6c92f86aec2b364cbe8f76343f
Binary files /dev/null and b/vggt/heads/__pycache__/track_head.cpython-311.pyc differ
diff --git a/vggt/heads/__pycache__/track_head.cpython-312.pyc b/vggt/heads/__pycache__/track_head.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3899d6e65cb83f79bcf7de8218df5faff7898a5f
Binary files /dev/null and b/vggt/heads/__pycache__/track_head.cpython-312.pyc differ
diff --git a/vggt/heads/__pycache__/utils.cpython-310.pyc b/vggt/heads/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..249f94fb8f68ea69f1b76846fd11df7ee7e476a7
Binary files /dev/null and b/vggt/heads/__pycache__/utils.cpython-310.pyc differ
diff --git a/vggt/heads/__pycache__/utils.cpython-311.pyc b/vggt/heads/__pycache__/utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..73f5259fa20f6838ed05414d1586023f0b546a8f
Binary files /dev/null and b/vggt/heads/__pycache__/utils.cpython-311.pyc differ
diff --git a/vggt/heads/__pycache__/utils.cpython-312.pyc b/vggt/heads/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d03b6aceb91c7abfea8397f546a57be4c60f726e
Binary files /dev/null and b/vggt/heads/__pycache__/utils.cpython-312.pyc differ
diff --git a/vggt/heads/camera_head.py b/vggt/heads/camera_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..176d76fb5baeb3a42fa3675a1d1fb14010f2904d
--- /dev/null
+++ b/vggt/heads/camera_head.py
@@ -0,0 +1,162 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from vggt.layers import Mlp
+from vggt.layers.block import Block
+from vggt.heads.head_act import activate_pose
+
+
+class CameraHead(nn.Module):
+ """
+ CameraHead predicts camera parameters from token representations using iterative refinement.
+
+ It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
+ """
+
+ def __init__(
+ self,
+ dim_in: int = 2048,
+ trunk_depth: int = 4,
+ pose_encoding_type: str = "absT_quaR_FoV",
+ num_heads: int = 16,
+ mlp_ratio: int = 4,
+ init_values: float = 0.01,
+ trans_act: str = "linear",
+ quat_act: str = "linear",
+ fl_act: str = "relu", # Field of view activations: ensures FOV values are positive.
+ ):
+ super().__init__()
+
+ if pose_encoding_type == "absT_quaR_FoV":
+ self.target_dim = 9
+ else:
+ raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
+
+ self.trans_act = trans_act
+ self.quat_act = quat_act
+ self.fl_act = fl_act
+ self.trunk_depth = trunk_depth
+
+ # Build the trunk using a sequence of transformer blocks.
+ self.trunk = nn.Sequential(
+ *[
+ Block(
+ dim=dim_in,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ init_values=init_values,
+ )
+ for _ in range(trunk_depth)
+ ]
+ )
+
+ # Normalizations for camera token and trunk output.
+ self.token_norm = nn.LayerNorm(dim_in)
+ self.trunk_norm = nn.LayerNorm(dim_in)
+
+ # Learnable empty camera pose token.
+ self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
+ self.embed_pose = nn.Linear(self.target_dim, dim_in)
+
+ # Module for producing modulation parameters: shift, scale, and a gate.
+ self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
+
+ # Adaptive layer normalization without affine parameters.
+ self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
+ self.pose_branch = Mlp(
+ in_features=dim_in,
+ hidden_features=dim_in // 2,
+ out_features=self.target_dim,
+ drop=0,
+ )
+
+ def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
+ """
+ Forward pass to predict camera parameters.
+
+ Args:
+ aggregated_tokens_list (list): List of token tensors from the network;
+ the last tensor is used for prediction.
+ num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
+
+ Returns:
+ list: A list of predicted camera encodings (post-activation) from each iteration.
+ """
+ # Use tokens from the last block for camera prediction.
+ tokens = aggregated_tokens_list[-1]
+
+ # Extract the camera tokens
+ pose_tokens = tokens[:, :, 0]
+ pose_tokens = self.token_norm(pose_tokens)
+
+ pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
+ return pred_pose_enc_list
+
+ def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
+ """
+ Iteratively refine camera pose predictions.
+
+ Args:
+ pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
+ num_iterations (int): Number of refinement iterations.
+
+ Returns:
+ list: List of activated camera encodings from each iteration.
+ """
+ B, S, C = pose_tokens.shape # S is expected to be 1.
+ pred_pose_enc = None
+ pred_pose_enc_list = []
+
+ for _ in range(num_iterations):
+ # Use a learned empty pose for the first iteration.
+ if pred_pose_enc is None:
+ module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
+ else:
+ # Detach the previous prediction to avoid backprop through time.
+ pred_pose_enc = pred_pose_enc.detach()
+ module_input = self.embed_pose(pred_pose_enc)
+
+ # Generate modulation parameters and split them into shift, scale, and gate components.
+ shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
+
+ # Adaptive layer normalization and modulation.
+ pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
+ pose_tokens_modulated = pose_tokens_modulated + pose_tokens
+
+ pose_tokens_modulated = self.trunk(pose_tokens_modulated)
+ # Compute the delta update for the pose encoding.
+ pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
+
+ if pred_pose_enc is None:
+ pred_pose_enc = pred_pose_enc_delta
+ else:
+ pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
+
+ # Apply final activation functions for translation, quaternion, and field-of-view.
+ activated_pose = activate_pose(
+ pred_pose_enc,
+ trans_act=self.trans_act,
+ quat_act=self.quat_act,
+ fl_act=self.fl_act,
+ )
+ pred_pose_enc_list.append(activated_pose)
+
+ return pred_pose_enc_list
+
+
+def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+ """
+ Modulate the input tensor using scaling and shifting parameters.
+ """
+ # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
+ return x * (1 + scale) + shift
diff --git a/vggt/heads/dpt_head.py b/vggt/heads/dpt_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc31b7da4589882b2dd7b52e47d3b30563bc9764
--- /dev/null
+++ b/vggt/heads/dpt_head.py
@@ -0,0 +1,497 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+# Inspired by https://github.com/DepthAnything/Depth-Anything-V2
+
+
+import os
+from typing import List, Dict, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .head_act import activate_head
+from .utils import create_uv_grid, position_grid_to_embed
+
+
+class DPTHead(nn.Module):
+ """
+ DPT Head for dense prediction tasks.
+
+ This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
+ (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
+ backbone and produces dense predictions by fusing multi-scale features.
+
+ Args:
+ dim_in (int): Input dimension (channels).
+ patch_size (int, optional): Patch size. Default is 14.
+ output_dim (int, optional): Number of output channels. Default is 4.
+ activation (str, optional): Activation type. Default is "inv_log".
+ conf_activation (str, optional): Confidence activation type. Default is "expp1".
+ features (int, optional): Feature channels for intermediate representations. Default is 256.
+ out_channels (List[int], optional): Output channels for each intermediate layer.
+ intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
+ pos_embed (bool, optional): Whether to use positional embedding. Default is True.
+ feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
+ down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
+ """
+
+ def __init__(
+ self,
+ dim_in: int,
+ patch_size: int = 14,
+ output_dim: int = 4,
+ activation: str = "inv_log",
+ conf_activation: str = "expp1",
+ features: int = 256,
+ out_channels: List[int] = [256, 512, 1024, 1024],
+ intermediate_layer_idx: List[int] = [4, 11, 17, 23],
+ pos_embed: bool = True,
+ feature_only: bool = False,
+ down_ratio: int = 1,
+ ) -> None:
+ super(DPTHead, self).__init__()
+ self.patch_size = patch_size
+ self.activation = activation
+ self.conf_activation = conf_activation
+ self.pos_embed = pos_embed
+ self.feature_only = feature_only
+ self.down_ratio = down_ratio
+ self.intermediate_layer_idx = intermediate_layer_idx
+
+ self.norm = nn.LayerNorm(dim_in)
+
+ # Projection layers for each output channel from tokens.
+ self.projects = nn.ModuleList(
+ [
+ nn.Conv2d(
+ in_channels=dim_in,
+ out_channels=oc,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+ for oc in out_channels
+ ]
+ )
+
+ # Resize layers for upsampling feature maps.
+ self.resize_layers = nn.ModuleList(
+ [
+ nn.ConvTranspose2d(
+ in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
+ ),
+ nn.ConvTranspose2d(
+ in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
+ ),
+ nn.Identity(),
+ nn.Conv2d(
+ in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
+ ),
+ ]
+ )
+
+ self.scratch = _make_scratch(
+ out_channels,
+ features,
+ expand=False,
+ )
+
+ # Attach additional modules to scratch.
+ self.scratch.stem_transpose = None
+ self.scratch.refinenet1 = _make_fusion_block(features)
+ self.scratch.refinenet2 = _make_fusion_block(features)
+ self.scratch.refinenet3 = _make_fusion_block(features)
+ self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
+
+ head_features_1 = features
+ head_features_2 = 32
+
+ if feature_only:
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
+ else:
+ self.scratch.output_conv1 = nn.Conv2d(
+ head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
+ )
+ conv2_in_channels = head_features_1 // 2
+
+ self.scratch.output_conv2 = nn.Sequential(
+ nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
+ )
+
+ def forward(
+ self,
+ aggregated_tokens_list: List[torch.Tensor],
+ images: torch.Tensor,
+ patch_start_idx: int,
+ frames_chunk_size: int = 8,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ Forward pass through the DPT head, supports processing by chunking frames.
+ Args:
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
+ images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
+ patch_start_idx (int): Starting index for patch tokens in the token sequence.
+ Used to separate patch tokens from other tokens (e.g., camera or register tokens).
+ frames_chunk_size (int, optional): Number of frames to process in each chunk.
+ If None or larger than S, all frames are processed at once. Default: 8.
+
+ Returns:
+ Tensor or Tuple[Tensor, Tensor]:
+ - If feature_only=True: Feature maps with shape [B, S, C, H, W]
+ - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
+ """
+ B, S, _, H, W = images.shape
+
+ # If frames_chunk_size is not specified or greater than S, process all frames at once
+ if frames_chunk_size is None or frames_chunk_size >= S:
+ return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)
+
+ # Otherwise, process frames in chunks to manage memory usage
+ assert frames_chunk_size > 0
+
+ # Process frames in batches
+ all_preds = []
+ all_conf = []
+
+ for frames_start_idx in range(0, S, frames_chunk_size):
+ frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
+
+ # Process batch of frames
+ if self.feature_only:
+ chunk_output = self._forward_impl(
+ aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
+ )
+ all_preds.append(chunk_output)
+ else:
+ chunk_preds, chunk_conf = self._forward_impl(
+ aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
+ )
+ all_preds.append(chunk_preds)
+ all_conf.append(chunk_conf)
+
+ # Concatenate results along the sequence dimension
+ if self.feature_only:
+ return torch.cat(all_preds, dim=1)
+ else:
+ return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)
+
+ def _forward_impl(
+ self,
+ aggregated_tokens_list: List[torch.Tensor],
+ images: torch.Tensor,
+ patch_start_idx: int,
+ frames_start_idx: int = None,
+ frames_end_idx: int = None,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ Implementation of the forward pass through the DPT head.
+
+ This method processes a specific chunk of frames from the sequence.
+
+ Args:
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
+ images (Tensor): Input images with shape [B, S, 3, H, W].
+ patch_start_idx (int): Starting index for patch tokens.
+ frames_start_idx (int, optional): Starting index for frames to process.
+ frames_end_idx (int, optional): Ending index for frames to process.
+
+ Returns:
+ Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
+ """
+ if frames_start_idx is not None and frames_end_idx is not None:
+ images = images[:, frames_start_idx:frames_end_idx].contiguous()
+
+ B, S, _, H, W = images.shape
+
+ patch_h, patch_w = H // self.patch_size, W // self.patch_size
+
+ out = []
+ dpt_idx = 0
+
+ for layer_idx in self.intermediate_layer_idx:
+ x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
+
+ # Select frames if processing a chunk
+ if frames_start_idx is not None and frames_end_idx is not None:
+ x = x[:, frames_start_idx:frames_end_idx]
+
+ x = x.reshape(B * S, -1, x.shape[-1])
+
+ x = self.norm(x)
+
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
+
+ x = self.projects[dpt_idx](x)
+ if self.pos_embed:
+ x = self._apply_pos_embed(x, W, H)
+ x = self.resize_layers[dpt_idx](x)
+
+ out.append(x)
+ dpt_idx += 1
+
+ # Fuse features from multiple layers.
+ out = self.scratch_forward(out)
+ # Interpolate fused output to match target image resolution.
+ out = custom_interpolate(
+ out,
+ (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
+ mode="bilinear",
+ align_corners=True,
+ )
+
+ if self.pos_embed:
+ out = self._apply_pos_embed(out, W, H)
+
+ if self.feature_only:
+ return out.reshape(B, S, *out.shape[1:])
+
+ out = self.scratch.output_conv2(out)
+ preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)
+
+ preds = preds.reshape(B, S, *preds.shape[1:])
+ conf = conf.reshape(B, S, *conf.shape[1:])
+ return preds, conf
+
+ def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
+ """
+ Apply positional embedding to tensor x.
+ """
+ patch_w = x.shape[-1]
+ patch_h = x.shape[-2]
+ pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
+ pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
+ pos_embed = pos_embed * ratio
+ pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
+ return x + pos_embed
+
+ def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
+ """
+ Forward pass through the fusion blocks.
+
+ Args:
+ features (List[Tensor]): List of feature maps from different layers.
+
+ Returns:
+ Tensor: Fused feature map.
+ """
+ layer_1, layer_2, layer_3, layer_4 = features
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
+ del layer_4_rn, layer_4
+
+ out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
+ del layer_3_rn, layer_3
+
+ out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
+ del layer_2_rn, layer_2
+
+ out = self.scratch.refinenet1(out, layer_1_rn)
+ del layer_1_rn, layer_1
+
+ out = self.scratch.output_conv1(out)
+ return out
+
+
+################################################################################
+# Modules
+################################################################################
+
+
+def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
+ return FeatureFusionBlock(
+ features,
+ nn.ReLU(inplace=True),
+ deconv=False,
+ bn=False,
+ expand=False,
+ align_corners=True,
+ size=size,
+ has_residual=has_residual,
+ groups=groups,
+ )
+
+
+def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
+ scratch = nn.Module()
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape
+
+ if expand:
+ out_shape1 = out_shape
+ out_shape2 = out_shape * 2
+ out_shape3 = out_shape * 4
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape * 8
+
+ scratch.layer1_rn = nn.Conv2d(
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer2_rn = nn.Conv2d(
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer3_rn = nn.Conv2d(
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ if len(in_shape) >= 4:
+ scratch.layer4_rn = nn.Conv2d(
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ return scratch
+
+
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module."""
+
+ def __init__(self, features, activation, bn, groups=1):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.bn = bn
+ self.groups = groups
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
+
+ self.norm1 = None
+ self.norm2 = None
+
+ self.activation = activation
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.norm1 is not None:
+ out = self.norm1(out)
+
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.norm2 is not None:
+ out = self.norm2(out)
+
+ return self.skip_add.add(out, x)
+
+
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block."""
+
+ def __init__(
+ self,
+ features,
+ activation,
+ deconv=False,
+ bn=False,
+ expand=False,
+ align_corners=True,
+ size=None,
+ has_residual=True,
+ groups=1,
+ ):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock, self).__init__()
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+ self.groups = groups
+ self.expand = expand
+ out_features = features
+ if self.expand == True:
+ out_features = features // 2
+
+ self.out_conv = nn.Conv2d(
+ features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
+ )
+
+ if has_residual:
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
+
+ self.has_residual = has_residual
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)
+
+ self.skip_add = nn.quantized.FloatFunctional()
+ self.size = size
+
+ def forward(self, *xs, size=None):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if self.has_residual:
+ res = self.resConfUnit1(xs[1])
+ output = self.skip_add.add(output, res)
+
+ output = self.resConfUnit2(output)
+
+ if (size is None) and (self.size is None):
+ modifier = {"scale_factor": 2}
+ elif size is None:
+ modifier = {"size": self.size}
+ else:
+ modifier = {"size": size}
+
+ output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
+ output = self.out_conv(output)
+
+ return output
+
+
+def custom_interpolate(
+ x: torch.Tensor,
+ size: Tuple[int, int] = None,
+ scale_factor: float = None,
+ mode: str = "bilinear",
+ align_corners: bool = True,
+) -> torch.Tensor:
+ """
+ Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
+ """
+ if size is None:
+ size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
+
+ INT_MAX = 1610612736
+
+ input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
+
+ if input_elements > INT_MAX:
+ chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
+ interpolated_chunks = [
+ nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks
+ ]
+ x = torch.cat(interpolated_chunks, dim=0)
+ return x.contiguous()
+ else:
+ return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
diff --git a/vggt/heads/head_act.py b/vggt/heads/head_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dedfcf1180a653dddc99623e60df625e5897489
--- /dev/null
+++ b/vggt/heads/head_act.py
@@ -0,0 +1,125 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+import torch.nn.functional as F
+
+
+def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
+ """
+ Activate pose parameters with specified activation functions.
+
+ Args:
+ pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
+ trans_act: Activation type for translation component
+ quat_act: Activation type for quaternion component
+ fl_act: Activation type for focal length component
+
+ Returns:
+ Activated pose parameters tensor
+ """
+ T = pred_pose_enc[..., :3]
+ quat = pred_pose_enc[..., 3:7]
+ fl = pred_pose_enc[..., 7:] # or fov
+
+ T = base_pose_act(T, trans_act)
+ quat = base_pose_act(quat, quat_act)
+ fl = base_pose_act(fl, fl_act) # or fov
+
+ pred_pose_enc = torch.cat([T, quat, fl], dim=-1)
+
+ return pred_pose_enc
+
+
+def base_pose_act(pose_enc, act_type="linear"):
+ """
+ Apply basic activation function to pose parameters.
+
+ Args:
+ pose_enc: Tensor containing encoded pose parameters
+ act_type: Activation type ("linear", "inv_log", "exp", "relu")
+
+ Returns:
+ Activated pose parameters
+ """
+ if act_type == "linear":
+ return pose_enc
+ elif act_type == "inv_log":
+ return inverse_log_transform(pose_enc)
+ elif act_type == "exp":
+ return torch.exp(pose_enc)
+ elif act_type == "relu":
+ return F.relu(pose_enc)
+ else:
+ raise ValueError(f"Unknown act_type: {act_type}")
+
+
+def activate_head(out, activation="norm_exp", conf_activation="expp1"):
+ """
+ Process network output to extract 3D points and confidence values.
+
+ Args:
+ out: Network output tensor (B, C, H, W)
+ activation: Activation type for 3D points
+ conf_activation: Activation type for confidence values
+
+ Returns:
+ Tuple of (3D points tensor, confidence tensor)
+ """
+ # Move channels from last dim to the 4th dimension => (B, H, W, C)
+ fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected
+
+ # Split into xyz (first C-1 channels) and confidence (last channel)
+ xyz = fmap[:, :, :, :-1]
+ conf = fmap[:, :, :, -1]
+
+ if activation == "norm_exp":
+ d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
+ xyz_normed = xyz / d
+ pts3d = xyz_normed * torch.expm1(d)
+ elif activation == "norm":
+ pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
+ elif activation == "exp":
+ pts3d = torch.exp(xyz)
+ elif activation == "relu":
+ pts3d = F.relu(xyz)
+ elif activation == "inv_log":
+ pts3d = inverse_log_transform(xyz)
+ elif activation == "xy_inv_log":
+ xy, z = xyz.split([2, 1], dim=-1)
+ z = inverse_log_transform(z)
+ pts3d = torch.cat([xy * z, z], dim=-1)
+ elif activation == "sigmoid":
+ pts3d = torch.sigmoid(xyz)
+ elif activation == "linear":
+ pts3d = xyz
+ else:
+ raise ValueError(f"Unknown activation: {activation}")
+
+ if conf_activation == "expp1":
+ conf_out = 1 + conf.exp()
+ elif conf_activation == "expp0":
+ conf_out = conf.exp()
+ elif conf_activation == "sigmoid":
+ conf_out = torch.sigmoid(conf)
+ else:
+ raise ValueError(f"Unknown conf_activation: {conf_activation}")
+
+ return pts3d, conf_out
+
+
+def inverse_log_transform(y):
+ """
+ Apply inverse log transform: sign(y) * (exp(|y|) - 1)
+
+ Args:
+ y: Input tensor
+
+ Returns:
+ Transformed tensor
+ """
+ return torch.sign(y) * (torch.expm1(torch.abs(y)))
diff --git a/vggt/heads/track_head.py b/vggt/heads/track_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ec7199bd185060989c236997f93b93f4fc77825
--- /dev/null
+++ b/vggt/heads/track_head.py
@@ -0,0 +1,108 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch.nn as nn
+from .dpt_head import DPTHead
+from .track_modules.base_track_predictor import BaseTrackerPredictor
+
+
+class TrackHead(nn.Module):
+ """
+ Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking.
+ The tracking is performed iteratively, refining predictions over multiple iterations.
+ """
+
+ def __init__(
+ self,
+ dim_in,
+ patch_size=14,
+ features=128,
+ iters=4,
+ predict_conf=True,
+ stride=2,
+ corr_levels=7,
+ corr_radius=4,
+ hidden_size=384,
+ ):
+ """
+ Initialize the TrackHead module.
+
+ Args:
+ dim_in (int): Input dimension of tokens from the backbone.
+ patch_size (int): Size of image patches used in the vision transformer.
+ features (int): Number of feature channels in the feature extractor output.
+ iters (int): Number of refinement iterations for tracking predictions.
+ predict_conf (bool): Whether to predict confidence scores for tracked points.
+ stride (int): Stride value for the tracker predictor.
+ corr_levels (int): Number of correlation pyramid levels
+ corr_radius (int): Radius for correlation computation, controlling the search area.
+ hidden_size (int): Size of hidden layers in the tracker network.
+ """
+ super().__init__()
+
+ self.patch_size = patch_size
+
+ # Feature extractor based on DPT architecture
+ # Processes tokens into feature maps for tracking
+ self.feature_extractor = DPTHead(
+ dim_in=dim_in,
+ patch_size=patch_size,
+ features=features,
+ feature_only=True, # Only output features, no activation
+ down_ratio=2, # Reduces spatial dimensions by factor of 2
+ pos_embed=False,
+ )
+
+ # Tracker module that predicts point trajectories
+ # Takes feature maps and predicts coordinates and visibility
+ self.tracker = BaseTrackerPredictor(
+ latent_dim=features, # Match the output_dim of feature extractor
+ predict_conf=predict_conf,
+ stride=stride,
+ corr_levels=corr_levels,
+ corr_radius=corr_radius,
+ hidden_size=hidden_size,
+ )
+
+ self.iters = iters
+
+ def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None):
+ """
+ Forward pass of the TrackHead.
+
+ Args:
+ aggregated_tokens_list (list): List of aggregated tokens from the backbone.
+ images (torch.Tensor): Input images of shape (B, S, C, H, W) where:
+ B = batch size, S = sequence length.
+ patch_start_idx (int): Starting index for patch tokens.
+ query_points (torch.Tensor, optional): Initial query points to track.
+ If None, points are initialized by the tracker.
+ iters (int, optional): Number of refinement iterations. If None, uses self.iters.
+
+ Returns:
+ tuple:
+ - coord_preds (torch.Tensor): Predicted coordinates for tracked points.
+ - vis_scores (torch.Tensor): Visibility scores for tracked points.
+ - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True).
+ """
+ B, S, _, H, W = images.shape
+
+ # Extract features from tokens
+ # feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2
+ feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx)
+
+ # Use default iterations if not specified
+ if iters is None:
+ iters = self.iters
+
+ # Perform tracking using the extracted features
+ coord_preds, vis_scores, conf_scores = self.tracker(
+ query_points=query_points,
+ fmaps=feature_maps,
+ iters=iters,
+ )
+
+ return coord_preds, vis_scores, conf_scores
diff --git a/vggt/heads/track_modules/__init__.py b/vggt/heads/track_modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa
--- /dev/null
+++ b/vggt/heads/track_modules/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/vggt/heads/track_modules/__pycache__/__init__.cpython-310.pyc b/vggt/heads/track_modules/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..06f349ee47e4ba67597cf329b74b69d72e952195
Binary files /dev/null and b/vggt/heads/track_modules/__pycache__/__init__.cpython-310.pyc differ
diff --git a/vggt/heads/track_modules/__pycache__/__init__.cpython-311.pyc b/vggt/heads/track_modules/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bb0692535a41a87e1e0cc7fee86138188dc9d8fd
Binary files /dev/null and b/vggt/heads/track_modules/__pycache__/__init__.cpython-311.pyc differ
diff --git a/vggt/heads/track_modules/__pycache__/__init__.cpython-312.pyc b/vggt/heads/track_modules/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d04dbf9cd4f1161cd2460f9dcd079f3f3f7678f0
Binary files /dev/null and b/vggt/heads/track_modules/__pycache__/__init__.cpython-312.pyc differ
diff --git a/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-310.pyc b/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e3b05c9d2de1ea89ff4019c55ba6de1bdd41fa83
Binary files /dev/null and b/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-310.pyc differ
diff --git a/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-311.pyc b/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cce940fa0bcdc1da9963812900a34047af2f3db6
Binary files /dev/null and b/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-311.pyc differ
diff --git a/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-312.pyc b/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ab63e60f1317c870103cdf56ab03b6fac022edaf
Binary files /dev/null and b/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-312.pyc differ
diff --git a/vggt/heads/track_modules/__pycache__/blocks.cpython-310.pyc b/vggt/heads/track_modules/__pycache__/blocks.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f0b8b67f009e4e20a92530d6a8a7d2505c5ce187
Binary files /dev/null and b/vggt/heads/track_modules/__pycache__/blocks.cpython-310.pyc differ
diff --git a/vggt/heads/track_modules/__pycache__/blocks.cpython-311.pyc b/vggt/heads/track_modules/__pycache__/blocks.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..67b4f7467baf6d4ce6eb44078aa68997c3ffcd43
Binary files /dev/null and b/vggt/heads/track_modules/__pycache__/blocks.cpython-311.pyc differ
diff --git a/vggt/heads/track_modules/__pycache__/blocks.cpython-312.pyc b/vggt/heads/track_modules/__pycache__/blocks.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2c5888fc278e7185ca78b53672a4b0cce248bdab
Binary files /dev/null and b/vggt/heads/track_modules/__pycache__/blocks.cpython-312.pyc differ
diff --git a/vggt/heads/track_modules/__pycache__/modules.cpython-310.pyc b/vggt/heads/track_modules/__pycache__/modules.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cb5118633e842a9613b0a2bd8fd4fec6d31dff71
Binary files /dev/null and b/vggt/heads/track_modules/__pycache__/modules.cpython-310.pyc differ
diff --git a/vggt/heads/track_modules/__pycache__/modules.cpython-311.pyc b/vggt/heads/track_modules/__pycache__/modules.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..912ba2a6dc74ecea87cd8f231a009e9785e5fd73
Binary files /dev/null and b/vggt/heads/track_modules/__pycache__/modules.cpython-311.pyc differ
diff --git a/vggt/heads/track_modules/__pycache__/modules.cpython-312.pyc b/vggt/heads/track_modules/__pycache__/modules.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ffb4929483f83e8bc6ad2091d335f89e2e6d5df7
Binary files /dev/null and b/vggt/heads/track_modules/__pycache__/modules.cpython-312.pyc differ
diff --git a/vggt/heads/track_modules/__pycache__/utils.cpython-310.pyc b/vggt/heads/track_modules/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bfc767f5a8ed809ff18fddbd0ceb12e752fa5ba2
Binary files /dev/null and b/vggt/heads/track_modules/__pycache__/utils.cpython-310.pyc differ
diff --git a/vggt/heads/track_modules/__pycache__/utils.cpython-311.pyc b/vggt/heads/track_modules/__pycache__/utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1f0b220ac55878862b8e4ec432abb3e94c17e30d
Binary files /dev/null and b/vggt/heads/track_modules/__pycache__/utils.cpython-311.pyc differ
diff --git a/vggt/heads/track_modules/__pycache__/utils.cpython-312.pyc b/vggt/heads/track_modules/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e0fba979c6305de05224706f9c3935fca9918c6d
Binary files /dev/null and b/vggt/heads/track_modules/__pycache__/utils.cpython-312.pyc differ
diff --git a/vggt/heads/track_modules/base_track_predictor.py b/vggt/heads/track_modules/base_track_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ce8ec4b66fff236e015d1bcaf85c8237a52be7a
--- /dev/null
+++ b/vggt/heads/track_modules/base_track_predictor.py
@@ -0,0 +1,209 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+
+
+from .blocks import EfficientUpdateFormer, CorrBlock
+from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed
+from .modules import Mlp
+
+
+class BaseTrackerPredictor(nn.Module):
+ def __init__(
+ self,
+ stride=1,
+ corr_levels=5,
+ corr_radius=4,
+ latent_dim=128,
+ hidden_size=384,
+ use_spaceatt=True,
+ depth=6,
+ max_scale=518,
+ predict_conf=True,
+ ):
+ super(BaseTrackerPredictor, self).__init__()
+ """
+ The base template to create a track predictor
+
+ Modified from https://github.com/facebookresearch/co-tracker/
+ and https://github.com/facebookresearch/vggsfm
+ """
+
+ self.stride = stride
+ self.latent_dim = latent_dim
+ self.corr_levels = corr_levels
+ self.corr_radius = corr_radius
+ self.hidden_size = hidden_size
+ self.max_scale = max_scale
+ self.predict_conf = predict_conf
+
+ self.flows_emb_dim = latent_dim // 2
+
+ self.corr_mlp = Mlp(
+ in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2,
+ hidden_features=self.hidden_size,
+ out_features=self.latent_dim,
+ )
+
+ self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4
+
+ self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim))
+
+ space_depth = depth if use_spaceatt else 0
+ time_depth = depth
+
+ self.updateformer = EfficientUpdateFormer(
+ space_depth=space_depth,
+ time_depth=time_depth,
+ input_dim=self.transformer_dim,
+ hidden_size=self.hidden_size,
+ output_dim=self.latent_dim + 2,
+ mlp_ratio=4.0,
+ add_space_attn=use_spaceatt,
+ )
+
+ self.fmap_norm = nn.LayerNorm(self.latent_dim)
+ self.ffeat_norm = nn.GroupNorm(1, self.latent_dim)
+
+ # A linear layer to update track feats at each iteration
+ self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU())
+
+ self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
+
+ if predict_conf:
+ self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
+
+ def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True):
+ """
+ query_points: B x N x 2, the number of batches, tracks, and xy
+ fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension.
+ note HH and WW is the size of feature maps instead of original images
+ """
+ B, N, D = query_points.shape
+ B, S, C, HH, WW = fmaps.shape
+
+ assert D == 2, "Input points must be 2D coordinates"
+
+ # apply a layernorm to fmaps here
+ fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2))
+ fmaps = fmaps.permute(0, 1, 4, 2, 3)
+
+ # Scale the input query_points because we may downsample the images
+ # by down_ratio or self.stride
+ # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map
+ # its query_points should be query_points/4
+ if down_ratio > 1:
+ query_points = query_points / float(down_ratio)
+
+ query_points = query_points / float(self.stride)
+
+ # Init with coords as the query points
+ # It means the search will start from the position of query points at the reference frames
+ coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)
+
+ # Sample/extract the features of the query points in the query frame
+ query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])
+
+ # init track feats by query feats
+ track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C
+ # back up the init coords
+ coords_backup = coords.clone()
+
+ fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius)
+
+ coord_preds = []
+
+ # Iterative Refinement
+ for _ in range(iters):
+ # Detach the gradients from the last iteration
+ # (in my experience, not very important for performance)
+ coords = coords.detach()
+
+ fcorrs = fcorr_fn.corr_sample(track_feats, coords)
+
+ corr_dim = fcorrs.shape[3]
+ fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim)
+ fcorrs_ = self.corr_mlp(fcorrs_)
+
+ # Movement of current coords relative to query points
+ flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
+
+ flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)
+
+ # (In my trials, it is also okay to just add the flows_emb instead of concat)
+ flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1)
+
+ track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
+
+ # Concatenate them as the input for the transformers
+ transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)
+
+ # 2D positional embed
+ # TODO: this can be much simplified
+ pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device)
+ sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0])
+
+ sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1)
+
+ x = transformer_input + sampled_pos_emb
+
+ # Add the query ref token to the track feats
+ query_ref_token = torch.cat(
+ [self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1
+ )
+ x = x + query_ref_token.to(x.device).to(x.dtype)
+
+ # B, N, S, C
+ x = rearrange(x, "(b n) s d -> b n s d", b=B)
+
+ # Compute the delta coordinates and delta track features
+ delta, _ = self.updateformer(x)
+
+ # BN, S, C
+ delta = rearrange(delta, " b n s d -> (b n) s d", b=B)
+ delta_coords_ = delta[:, :, :2]
+ delta_feats_ = delta[:, :, 2:]
+
+ track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
+ delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
+
+ # Update the track features
+ track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_
+
+ track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC
+
+ # B x S x N x 2
+ coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
+
+ # Force coord0 as query
+ # because we assume the query points should not be changed
+ coords[:, 0] = coords_backup[:, 0]
+
+ # The predicted tracks are in the original image scale
+ if down_ratio > 1:
+ coord_preds.append(coords * self.stride * down_ratio)
+ else:
+ coord_preds.append(coords * self.stride)
+
+ # B, S, N
+ vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
+ if apply_sigmoid:
+ vis_e = torch.sigmoid(vis_e)
+
+ if self.predict_conf:
+ conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
+ if apply_sigmoid:
+ conf_e = torch.sigmoid(conf_e)
+ else:
+ conf_e = None
+
+ if return_feat:
+ return coord_preds, vis_e, track_feats, query_track_feat, conf_e
+ else:
+ return coord_preds, vis_e, conf_e
diff --git a/vggt/heads/track_modules/blocks.py b/vggt/heads/track_modules/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e7763f4fd8f515662421db192594380dbb574e5
--- /dev/null
+++ b/vggt/heads/track_modules/blocks.py
@@ -0,0 +1,246 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+# Modified from https://github.com/facebookresearch/co-tracker/
+
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .utils import bilinear_sampler
+from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock
+
+
+class EfficientUpdateFormer(nn.Module):
+ """
+ Transformer model that updates track estimates.
+ """
+
+ def __init__(
+ self,
+ space_depth=6,
+ time_depth=6,
+ input_dim=320,
+ hidden_size=384,
+ num_heads=8,
+ output_dim=130,
+ mlp_ratio=4.0,
+ add_space_attn=True,
+ num_virtual_tracks=64,
+ ):
+ super().__init__()
+
+ self.out_channels = 2
+ self.num_heads = num_heads
+ self.hidden_size = hidden_size
+ self.add_space_attn = add_space_attn
+
+ # Add input LayerNorm before linear projection
+ self.input_norm = nn.LayerNorm(input_dim)
+ self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
+
+ # Add output LayerNorm before final projection
+ self.output_norm = nn.LayerNorm(hidden_size)
+ self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
+ self.num_virtual_tracks = num_virtual_tracks
+
+ if self.add_space_attn:
+ self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))
+ else:
+ self.virual_tracks = None
+
+ self.time_blocks = nn.ModuleList(
+ [
+ AttnBlock(
+ hidden_size,
+ num_heads,
+ mlp_ratio=mlp_ratio,
+ attn_class=nn.MultiheadAttention,
+ )
+ for _ in range(time_depth)
+ ]
+ )
+
+ if add_space_attn:
+ self.space_virtual_blocks = nn.ModuleList(
+ [
+ AttnBlock(
+ hidden_size,
+ num_heads,
+ mlp_ratio=mlp_ratio,
+ attn_class=nn.MultiheadAttention,
+ )
+ for _ in range(space_depth)
+ ]
+ )
+ self.space_point2virtual_blocks = nn.ModuleList(
+ [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
+ )
+ self.space_virtual2point_blocks = nn.ModuleList(
+ [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
+ )
+ assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
+
+ self.apply(_basic_init)
+
+ def forward(self, input_tensor, mask=None):
+ # Apply input LayerNorm
+ input_tensor = self.input_norm(input_tensor)
+ tokens = self.input_transform(input_tensor)
+
+ init_tokens = tokens
+
+ B, _, T, _ = tokens.shape
+
+ if self.add_space_attn:
+ virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
+ tokens = torch.cat([tokens, virtual_tokens], dim=1)
+
+ _, N, _, _ = tokens.shape
+
+ j = 0
+ for i in range(len(self.time_blocks)):
+ time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
+
+ time_tokens = self.time_blocks[i](time_tokens)
+
+ tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
+ if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0):
+ space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C
+ point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
+ virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
+
+ virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask)
+ virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
+ point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask)
+
+ space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
+ tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C
+ j += 1
+
+ if self.add_space_attn:
+ tokens = tokens[:, : N - self.num_virtual_tracks]
+
+ tokens = tokens + init_tokens
+
+ # Apply output LayerNorm before final projection
+ tokens = self.output_norm(tokens)
+ flow = self.flow_head(tokens)
+
+ return flow, None
+
+
+class CorrBlock:
+ def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"):
+ """
+ Build a pyramid of feature maps from the input.
+
+ fmaps: Tensor (B, S, C, H, W)
+ num_levels: number of pyramid levels (each downsampled by factor 2)
+ radius: search radius for sampling correlation
+ multiple_track_feats: if True, split the target features per pyramid level
+ padding_mode: passed to grid_sample / bilinear_sampler
+ """
+ B, S, C, H, W = fmaps.shape
+ self.S, self.C, self.H, self.W = S, C, H, W
+ self.num_levels = num_levels
+ self.radius = radius
+ self.padding_mode = padding_mode
+ self.multiple_track_feats = multiple_track_feats
+
+ # Build pyramid: each level is half the spatial resolution of the previous
+ self.fmaps_pyramid = [fmaps] # level 0 is full resolution
+ current_fmaps = fmaps
+ for i in range(num_levels - 1):
+ B, S, C, H, W = current_fmaps.shape
+ # Merge batch & sequence dimensions
+ current_fmaps = current_fmaps.reshape(B * S, C, H, W)
+ # Avg pool down by factor 2
+ current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2)
+ _, _, H_new, W_new = current_fmaps.shape
+ current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new)
+ self.fmaps_pyramid.append(current_fmaps)
+
+ # Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling.
+ # This grid is added to the (scaled) coordinate centroids.
+ r = self.radius
+ dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
+ dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
+ # delta: for every (dy,dx) displacement (i.e. Δx, Δy)
+ self.delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1) # shape: (2r+1, 2r+1, 2)
+
+ def corr_sample(self, targets, coords):
+ """
+ Instead of storing the entire correlation pyramid, we compute each level's correlation
+ volume, sample it immediately, then discard it. This saves GPU memory.
+
+ Args:
+ targets: Tensor (B, S, N, C) — features for the current targets.
+ coords: Tensor (B, S, N, 2) — coordinates at full resolution.
+
+ Returns:
+ Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations)
+ """
+ B, S, N, C = targets.shape
+
+ # If you have multiple track features, split them per level.
+ if self.multiple_track_feats:
+ targets_split = torch.split(targets, C // self.num_levels, dim=-1)
+
+ out_pyramid = []
+ for i, fmaps in enumerate(self.fmaps_pyramid):
+ # Get current spatial resolution H, W for this pyramid level.
+ B, S, C, H, W = fmaps.shape
+ # Reshape feature maps for correlation computation:
+ # fmap2s: (B, S, C, H*W)
+ fmap2s = fmaps.view(B, S, C, H * W)
+ # Choose appropriate target features.
+ fmap1 = targets_split[i] if self.multiple_track_feats else targets # shape: (B, S, N, C)
+
+ # Compute correlation directly
+ corrs = compute_corr_level(fmap1, fmap2s, C)
+ corrs = corrs.view(B, S, N, H, W)
+
+ # Prepare sampling grid:
+ # Scale down the coordinates for the current level.
+ centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i)
+ # Make sure our precomputed delta grid is on the same device/dtype.
+ delta_lvl = self.delta.to(coords.device).to(coords.dtype)
+ # Now the grid for grid_sample is:
+ # coords_lvl = centroid_lvl + delta_lvl (broadcasted over grid)
+ coords_lvl = centroid_lvl + delta_lvl.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2)
+
+ # Sample from the correlation volume using bilinear interpolation.
+ # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target.
+ corrs_sampled = bilinear_sampler(
+ corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode
+ )
+ # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims.
+ corrs_sampled = corrs_sampled.view(B, S, N, -1) # Now shape: (B, S, N, (2r+1)^2)
+ out_pyramid.append(corrs_sampled)
+
+ # Concatenate all levels along the last dimension.
+ out = torch.cat(out_pyramid, dim=-1).contiguous()
+ return out
+
+
+def compute_corr_level(fmap1, fmap2s, C):
+ # fmap1: (B, S, N, C)
+ # fmap2s: (B, S, C, H*W)
+ corrs = torch.matmul(fmap1, fmap2s) # (B, S, N, H*W)
+ corrs = corrs.view(fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1) # (B, S, N, H*W)
+ return corrs / math.sqrt(C)
diff --git a/vggt/heads/track_modules/modules.py b/vggt/heads/track_modules/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b090ddc4a9db01c8dd3564f9053e1ca9cdde93a
--- /dev/null
+++ b/vggt/heads/track_modules/modules.py
@@ -0,0 +1,218 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+from typing import Callable
+import collections
+from torch import Tensor
+from itertools import repeat
+
+
+# From PyTorch internals
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+ return tuple(x)
+ return tuple(repeat(x, n))
+
+ return parse
+
+
+def exists(val):
+ return val is not None
+
+
+def default(val, d):
+ return val if exists(val) else d
+
+
+to_2tuple = _ntuple(2)
+
+
+class ResidualBlock(nn.Module):
+ """
+ ResidualBlock: construct a block of two conv layers with residual connections
+ """
+
+ def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
+ super(ResidualBlock, self).__init__()
+
+ self.conv1 = nn.Conv2d(
+ in_planes,
+ planes,
+ kernel_size=kernel_size,
+ padding=1,
+ stride=stride,
+ padding_mode="zeros",
+ )
+ self.conv2 = nn.Conv2d(
+ planes,
+ planes,
+ kernel_size=kernel_size,
+ padding=1,
+ padding_mode="zeros",
+ )
+ self.relu = nn.ReLU(inplace=True)
+
+ num_groups = planes // 8
+
+ if norm_fn == "group":
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ if not stride == 1:
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+
+ elif norm_fn == "batch":
+ self.norm1 = nn.BatchNorm2d(planes)
+ self.norm2 = nn.BatchNorm2d(planes)
+ if not stride == 1:
+ self.norm3 = nn.BatchNorm2d(planes)
+
+ elif norm_fn == "instance":
+ self.norm1 = nn.InstanceNorm2d(planes)
+ self.norm2 = nn.InstanceNorm2d(planes)
+ if not stride == 1:
+ self.norm3 = nn.InstanceNorm2d(planes)
+
+ elif norm_fn == "none":
+ self.norm1 = nn.Sequential()
+ self.norm2 = nn.Sequential()
+ if not stride == 1:
+ self.norm3 = nn.Sequential()
+ else:
+ raise NotImplementedError
+
+ if stride == 1:
+ self.downsample = None
+ else:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
+ self.norm3,
+ )
+
+ def forward(self, x):
+ y = x
+ y = self.relu(self.norm1(self.conv1(y)))
+ y = self.relu(self.norm2(self.conv2(y)))
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ return self.relu(x + y)
+
+
+class Mlp(nn.Module):
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
+
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ norm_layer=None,
+ bias=True,
+ drop=0.0,
+ use_conv=False,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
+
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+
+class AttnBlock(nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ num_heads,
+ attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
+ mlp_ratio=4.0,
+ **block_kwargs
+ ):
+ """
+ Self attention block
+ """
+ super().__init__()
+
+ self.norm1 = nn.LayerNorm(hidden_size)
+ self.norm2 = nn.LayerNorm(hidden_size)
+
+ self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs)
+
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
+
+ def forward(self, x, mask=None):
+ # Prepare the mask for PyTorch's attention (it expects a different format)
+ # attn_mask = mask if mask is not None else None
+ # Normalize before attention
+ x = self.norm1(x)
+
+ # PyTorch's MultiheadAttention returns attn_output, attn_output_weights
+ # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask)
+
+ attn_output, _ = self.attn(x, x, x)
+
+ # Add & Norm
+ x = x + attn_output
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+
+class CrossAttnBlock(nn.Module):
+ def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
+ """
+ Cross attention block
+ """
+ super().__init__()
+
+ self.norm1 = nn.LayerNorm(hidden_size)
+ self.norm_context = nn.LayerNorm(hidden_size)
+ self.norm2 = nn.LayerNorm(hidden_size)
+
+ self.cross_attn = nn.MultiheadAttention(
+ embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
+ )
+
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
+
+ def forward(self, x, context, mask=None):
+ # Normalize inputs
+ x = self.norm1(x)
+ context = self.norm_context(context)
+
+ # Apply cross attention
+ # Note: nn.MultiheadAttention returns attn_output, attn_output_weights
+ attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask)
+
+ # Add & Norm
+ x = x + attn_output
+ x = x + self.mlp(self.norm2(x))
+ return x
diff --git a/vggt/heads/track_modules/utils.py b/vggt/heads/track_modules/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..51d01d39cdc10388a04dab5db7cf409b31bde766
--- /dev/null
+++ b/vggt/heads/track_modules/utils.py
@@ -0,0 +1,226 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from https://github.com/facebookresearch/vggsfm
+# and https://github.com/facebookresearch/co-tracker/tree/main
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from typing import Optional, Tuple, Union
+
+
+def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor:
+ """
+ This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
+ It is a wrapper of get_2d_sincos_pos_embed_from_grid.
+ Args:
+ - embed_dim: The embedding dimension.
+ - grid_size: The grid size.
+ Returns:
+ - pos_embed: The generated 2D positional embedding.
+ """
+ if isinstance(grid_size, tuple):
+ grid_size_h, grid_size_w = grid_size
+ else:
+ grid_size_h = grid_size_w = grid_size
+ grid_h = torch.arange(grid_size_h, dtype=torch.float)
+ grid_w = torch.arange(grid_size_w, dtype=torch.float)
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
+ grid = torch.stack(grid, dim=0)
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if return_grid:
+ return (
+ pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2),
+ grid,
+ )
+ return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor:
+ """
+ This function generates a 2D positional embedding from a given grid using sine and cosine functions.
+
+ Args:
+ - embed_dim: The embedding dimension.
+ - grid: The grid to generate the embedding from.
+
+ Returns:
+ - emb: The generated 2D positional embedding.
+ """
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:
+ """
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
+
+ Args:
+ - embed_dim: The embedding dimension.
+ - pos: The position to generate the embedding from.
+
+ Returns:
+ - emb: The generated 1D positional embedding.
+ """
+ assert embed_dim % 2 == 0
+ omega = torch.arange(embed_dim // 2, dtype=torch.double)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = torch.sin(out) # (M, D/2)
+ emb_cos = torch.cos(out) # (M, D/2)
+
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
+ return emb[None].float()
+
+
+def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
+ """
+ This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
+
+ Args:
+ - xy: The coordinates to generate the embedding from.
+ - C: The size of the embedding.
+ - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
+
+ Returns:
+ - pe: The generated 2D positional embedding.
+ """
+ B, N, D = xy.shape
+ assert D == 2
+
+ x = xy[:, :, 0:1]
+ y = xy[:, :, 1:2]
+ div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2))
+
+ pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
+ pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
+
+ pe_x[:, :, 0::2] = torch.sin(x * div_term)
+ pe_x[:, :, 1::2] = torch.cos(x * div_term)
+
+ pe_y[:, :, 0::2] = torch.sin(y * div_term)
+ pe_y[:, :, 1::2] = torch.cos(y * div_term)
+
+ pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3)
+ if cat_coords:
+ pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3)
+ return pe
+
+
+def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
+ r"""Sample a tensor using bilinear interpolation
+
+ `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
+ coordinates :attr:`coords` using bilinear interpolation. It is the same
+ as `torch.nn.functional.grid_sample()` but with a different coordinate
+ convention.
+
+ The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
+ :math:`B` is the batch size, :math:`C` is the number of channels,
+ :math:`H` is the height of the image, and :math:`W` is the width of the
+ image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
+ interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
+
+ Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
+ in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
+ that in this case the order of the components is slightly different
+ from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
+
+ If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
+ in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
+ left-most image pixel :math:`W-1` to the center of the right-most
+ pixel.
+
+ If `align_corners` is `False`, the coordinate :math:`x` is assumed to
+ be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
+ the left-most pixel :math:`W` to the right edge of the right-most
+ pixel.
+
+ Similar conventions apply to the :math:`y` for the range
+ :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
+ :math:`[0,T-1]` and :math:`[0,T]`.
+
+ Args:
+ input (Tensor): batch of input images.
+ coords (Tensor): batch of coordinates.
+ align_corners (bool, optional): Coordinate convention. Defaults to `True`.
+ padding_mode (str, optional): Padding mode. Defaults to `"border"`.
+
+ Returns:
+ Tensor: sampled points.
+ """
+ coords = coords.detach().clone()
+ ############################################################
+ # IMPORTANT:
+ coords = coords.to(input.device).to(input.dtype)
+ ############################################################
+
+ sizes = input.shape[2:]
+
+ assert len(sizes) in [2, 3]
+
+ if len(sizes) == 3:
+ # t x y -> x y t to match dimensions T H W in grid_sample
+ coords = coords[..., [1, 2, 0]]
+
+ if align_corners:
+ scale = torch.tensor(
+ [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device, dtype=coords.dtype
+ )
+ else:
+ scale = torch.tensor([2 / size for size in reversed(sizes)], device=coords.device, dtype=coords.dtype)
+
+ coords.mul_(scale) # coords = coords * scale
+ coords.sub_(1) # coords = coords - 1
+
+ return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
+
+
+def sample_features4d(input, coords):
+ r"""Sample spatial features
+
+ `sample_features4d(input, coords)` samples the spatial features
+ :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
+
+ The field is sampled at coordinates :attr:`coords` using bilinear
+ interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
+ 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
+ same convention as :func:`bilinear_sampler` with `align_corners=True`.
+
+ The output tensor has one feature per point, and has shape :math:`(B,
+ R, C)`.
+
+ Args:
+ input (Tensor): spatial features.
+ coords (Tensor): points.
+
+ Returns:
+ Tensor: sampled features.
+ """
+
+ B, _, _, _ = input.shape
+
+ # B R 2 -> B R 1 2
+ coords = coords.unsqueeze(2)
+
+ # B C R 1
+ feats = bilinear_sampler(input, coords)
+
+ return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C
diff --git a/vggt/heads/utils.py b/vggt/heads/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7af1f68fa0ce0a48d11a708d53aa20aa8f78ba2
--- /dev/null
+++ b/vggt/heads/utils.py
@@ -0,0 +1,108 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+
+
+def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
+ """
+ Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)
+
+ Args:
+ pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
+ embed_dim: Output channel dimension for embeddings
+
+ Returns:
+ Tensor of shape (H, W, embed_dim) with positional embeddings
+ """
+ H, W, grid_dim = pos_grid.shape
+ assert grid_dim == 2
+ pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2)
+
+ # Process x and y coordinates separately
+ emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2]
+ emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2]
+
+ # Combine and reshape
+ emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D]
+
+ return emb.view(H, W, embed_dim) # [H, W, D]
+
+
+def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
+ """
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
+
+ Args:
+ - embed_dim: The embedding dimension.
+ - pos: The position to generate the embedding from.
+
+ Returns:
+ - emb: The generated 1D positional embedding.
+ """
+ assert embed_dim % 2 == 0
+ omega = torch.arange(embed_dim // 2, dtype=torch.double, device=pos.device)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / omega_0**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = torch.sin(out) # (M, D/2)
+ emb_cos = torch.cos(out) # (M, D/2)
+
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
+ return emb.float()
+
+
+# Inspired by https://github.com/microsoft/moge
+
+
+def create_uv_grid(
+ width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
+) -> torch.Tensor:
+ """
+ Create a normalized UV grid of shape (width, height, 2).
+
+ The grid spans horizontally and vertically according to an aspect ratio,
+ ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
+ corner is at (x_span, y_span), normalized by the diagonal of the plane.
+
+ Args:
+ width (int): Number of points horizontally.
+ height (int): Number of points vertically.
+ aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
+ dtype (torch.dtype, optional): Data type of the resulting tensor.
+ device (torch.device, optional): Device on which the tensor is created.
+
+ Returns:
+ torch.Tensor: A (width, height, 2) tensor of UV coordinates.
+ """
+ # Derive aspect ratio if not explicitly provided
+ if aspect_ratio is None:
+ aspect_ratio = float(width) / float(height)
+
+ # Compute normalized spans for X and Y
+ diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
+ span_x = aspect_ratio / diag_factor
+ span_y = 1.0 / diag_factor
+
+ # Establish the linspace boundaries
+ left_x = -span_x * (width - 1) / width
+ right_x = span_x * (width - 1) / width
+ top_y = -span_y * (height - 1) / height
+ bottom_y = span_y * (height - 1) / height
+
+ # Generate 1D coordinates
+ x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
+ y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
+
+ # Create 2D meshgrid (width x height) and stack into UV
+ uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
+ uv_grid = torch.stack((uu, vv), dim=-1)
+
+ return uv_grid
diff --git a/vggt/layers/__init__.py b/vggt/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8120f4bc83066cb3f825ce32daa3b437f88486f1
--- /dev/null
+++ b/vggt/layers/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .mlp import Mlp
+from .patch_embed import PatchEmbed
+from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+from .block import NestedTensorBlock
+from .attention import MemEffAttention
diff --git a/vggt/layers/__pycache__/__init__.cpython-310.pyc b/vggt/layers/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7f928c0d84fcc9306a05350326386523b51c9dee
Binary files /dev/null and b/vggt/layers/__pycache__/__init__.cpython-310.pyc differ
diff --git a/vggt/layers/__pycache__/__init__.cpython-311.pyc b/vggt/layers/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e52510e608fd506e4d4b743ec0783996248da7e8
Binary files /dev/null and b/vggt/layers/__pycache__/__init__.cpython-311.pyc differ
diff --git a/vggt/layers/__pycache__/__init__.cpython-312.pyc b/vggt/layers/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c78ee56d649d02592480f9a936e056cacfac54fa
Binary files /dev/null and b/vggt/layers/__pycache__/__init__.cpython-312.pyc differ
diff --git a/vggt/layers/__pycache__/attention.cpython-310.pyc b/vggt/layers/__pycache__/attention.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f53c19f53def6e22dce3330bc0b6f45030a16af5
Binary files /dev/null and b/vggt/layers/__pycache__/attention.cpython-310.pyc differ
diff --git a/vggt/layers/__pycache__/attention.cpython-311.pyc b/vggt/layers/__pycache__/attention.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..01aa3d060f9cb8a587f388bf85842ef8c705f523
Binary files /dev/null and b/vggt/layers/__pycache__/attention.cpython-311.pyc differ
diff --git a/vggt/layers/__pycache__/attention.cpython-312.pyc b/vggt/layers/__pycache__/attention.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f5c5a87b6ee9c79adf4db8e46a6d56e45c447d9a
Binary files /dev/null and b/vggt/layers/__pycache__/attention.cpython-312.pyc differ
diff --git a/vggt/layers/__pycache__/block.cpython-310.pyc b/vggt/layers/__pycache__/block.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f3d104e68e5daa6a1607f5ec05460857316ff67d
Binary files /dev/null and b/vggt/layers/__pycache__/block.cpython-310.pyc differ
diff --git a/vggt/layers/__pycache__/block.cpython-311.pyc b/vggt/layers/__pycache__/block.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..beb9ee3180c268edb3e84a623051c5cde5cebd33
Binary files /dev/null and b/vggt/layers/__pycache__/block.cpython-311.pyc differ
diff --git a/vggt/layers/__pycache__/block.cpython-312.pyc b/vggt/layers/__pycache__/block.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c2c0e4319ed4976018b3a8a3befb30d26f1f2e43
Binary files /dev/null and b/vggt/layers/__pycache__/block.cpython-312.pyc differ
diff --git a/vggt/layers/__pycache__/drop_path.cpython-310.pyc b/vggt/layers/__pycache__/drop_path.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9e9111b69beab87a38dc21e66964b3c128992988
Binary files /dev/null and b/vggt/layers/__pycache__/drop_path.cpython-310.pyc differ
diff --git a/vggt/layers/__pycache__/drop_path.cpython-311.pyc b/vggt/layers/__pycache__/drop_path.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..51e4133db7788592cdf51a9ebeb494481bca0460
Binary files /dev/null and b/vggt/layers/__pycache__/drop_path.cpython-311.pyc differ
diff --git a/vggt/layers/__pycache__/drop_path.cpython-312.pyc b/vggt/layers/__pycache__/drop_path.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8db6ed8659d56003ec6e30e8844cacdee7650183
Binary files /dev/null and b/vggt/layers/__pycache__/drop_path.cpython-312.pyc differ
diff --git a/vggt/layers/__pycache__/layer_scale.cpython-310.pyc b/vggt/layers/__pycache__/layer_scale.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fdfa58ee3b7335bf9fac09afe34c28bb9b565fd6
Binary files /dev/null and b/vggt/layers/__pycache__/layer_scale.cpython-310.pyc differ
diff --git a/vggt/layers/__pycache__/layer_scale.cpython-311.pyc b/vggt/layers/__pycache__/layer_scale.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e8d717d5f486cea02877ada1fc48b80a74cd97da
Binary files /dev/null and b/vggt/layers/__pycache__/layer_scale.cpython-311.pyc differ
diff --git a/vggt/layers/__pycache__/layer_scale.cpython-312.pyc b/vggt/layers/__pycache__/layer_scale.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1336f757c18dec65a66c97fa75018a7f90ea7481
Binary files /dev/null and b/vggt/layers/__pycache__/layer_scale.cpython-312.pyc differ
diff --git a/vggt/layers/__pycache__/mlp.cpython-310.pyc b/vggt/layers/__pycache__/mlp.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fb277a307b61f7eef0dec6804a3e95bfb5f3d1a1
Binary files /dev/null and b/vggt/layers/__pycache__/mlp.cpython-310.pyc differ
diff --git a/vggt/layers/__pycache__/mlp.cpython-311.pyc b/vggt/layers/__pycache__/mlp.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f5dfa3abd939bb8ea850b5fcf15589eb2cfc7751
Binary files /dev/null and b/vggt/layers/__pycache__/mlp.cpython-311.pyc differ
diff --git a/vggt/layers/__pycache__/mlp.cpython-312.pyc b/vggt/layers/__pycache__/mlp.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..042f9fad55a45f236884c13168bce4f1d31e16de
Binary files /dev/null and b/vggt/layers/__pycache__/mlp.cpython-312.pyc differ
diff --git a/vggt/layers/__pycache__/patch_embed.cpython-310.pyc b/vggt/layers/__pycache__/patch_embed.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1dce720bf6ada9e1284469fe0611439970471220
Binary files /dev/null and b/vggt/layers/__pycache__/patch_embed.cpython-310.pyc differ
diff --git a/vggt/layers/__pycache__/patch_embed.cpython-311.pyc b/vggt/layers/__pycache__/patch_embed.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..178a450b8d438247ca3e94c47cecce6ab8165074
Binary files /dev/null and b/vggt/layers/__pycache__/patch_embed.cpython-311.pyc differ
diff --git a/vggt/layers/__pycache__/patch_embed.cpython-312.pyc b/vggt/layers/__pycache__/patch_embed.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c8c4863004642f540170ad6cd3ef97cdc79a7a60
Binary files /dev/null and b/vggt/layers/__pycache__/patch_embed.cpython-312.pyc differ
diff --git a/vggt/layers/__pycache__/rope.cpython-310.pyc b/vggt/layers/__pycache__/rope.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6cd015fda482a648a2ab175c42a2feb018ff5f3a
Binary files /dev/null and b/vggt/layers/__pycache__/rope.cpython-310.pyc differ
diff --git a/vggt/layers/__pycache__/rope.cpython-311.pyc b/vggt/layers/__pycache__/rope.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5adb62450925c2c8611ebcc307f2b71fe00fc16f
Binary files /dev/null and b/vggt/layers/__pycache__/rope.cpython-311.pyc differ
diff --git a/vggt/layers/__pycache__/rope.cpython-312.pyc b/vggt/layers/__pycache__/rope.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b08d42d1c80a796a2b9b2d62a4401f20bf22b467
Binary files /dev/null and b/vggt/layers/__pycache__/rope.cpython-312.pyc differ
diff --git a/vggt/layers/__pycache__/swiglu_ffn.cpython-310.pyc b/vggt/layers/__pycache__/swiglu_ffn.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b354cc7e93643dacb2d6e14dd5d35d14da0cbc83
Binary files /dev/null and b/vggt/layers/__pycache__/swiglu_ffn.cpython-310.pyc differ
diff --git a/vggt/layers/__pycache__/swiglu_ffn.cpython-311.pyc b/vggt/layers/__pycache__/swiglu_ffn.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..854f37a5df453f5600ff1d9582e306c764b6238d
Binary files /dev/null and b/vggt/layers/__pycache__/swiglu_ffn.cpython-311.pyc differ
diff --git a/vggt/layers/__pycache__/swiglu_ffn.cpython-312.pyc b/vggt/layers/__pycache__/swiglu_ffn.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..44534316fd90d45d1e017c6d4211c399897c3de8
Binary files /dev/null and b/vggt/layers/__pycache__/swiglu_ffn.cpython-312.pyc differ
diff --git a/vggt/layers/__pycache__/vision_transformer.cpython-310.pyc b/vggt/layers/__pycache__/vision_transformer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..188afddb380e49d554bc2432e8634cce7a123e48
Binary files /dev/null and b/vggt/layers/__pycache__/vision_transformer.cpython-310.pyc differ
diff --git a/vggt/layers/__pycache__/vision_transformer.cpython-311.pyc b/vggt/layers/__pycache__/vision_transformer.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..28338bfeb0dd2c309ec434379e704fbc5284b04a
Binary files /dev/null and b/vggt/layers/__pycache__/vision_transformer.cpython-311.pyc differ
diff --git a/vggt/layers/__pycache__/vision_transformer.cpython-312.pyc b/vggt/layers/__pycache__/vision_transformer.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c8a84c4671b3b93a38a238b48c459b7bbca94c58
Binary files /dev/null and b/vggt/layers/__pycache__/vision_transformer.cpython-312.pyc differ
diff --git a/vggt/layers/attention.py b/vggt/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..55ce51b5539af99b4254b59599dfd3ccd7bae182
--- /dev/null
+++ b/vggt/layers/attention.py
@@ -0,0 +1,326 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+import os
+import warnings
+
+import torch
+from torch import Tensor
+import torch.nn.functional as F
+from torch import nn
+from typing import Union, Tuple, Dict, Optional
+
+from einops import rearrange
+
+XFORMERS_AVAILABLE = False
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = True,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ norm_layer: nn.Module = nn.LayerNorm,
+ qk_norm: bool = False,
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
+ rope=None,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim**-0.5
+ self.fused_attn = fused_attn
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.rope = rope
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ pos=None,
+ attn_mask=None,
+ past_key_values=None,
+ use_cache=False
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple]]:
+ if False:
+ torch.set_printoptions(threshold=float('inf'))
+ torch.set_printoptions(precision=10)
+ log_file = "baseline.log"
+
+ first_dimension = 0
+ second_dimension = 1
+ third_dimension = 0
+ fourth_dimension_start = 10
+ fourth_dimension_end = 20
+
+ if False:
+ with open(log_file, "a") as f:
+ f.write(f"--- Forward ---\n")
+ f.write(f"X shape {x.shape}, dtype: {x.dtype}\n")
+ with torch.no_grad():
+ x_mean = x.float().mean().item()
+ x_max = x.float().max().item()
+ some_x = x[0, :-100, 10]
+ f.write(f" X stats: mean={x_mean:.6f}, max={x_max:.6f}\n")
+ f.write(f" Some X stats: {some_x}\n")
+
+ B, N, C = x.shape
+
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0)
+
+ if False:
+ with open(log_file, "a") as f:
+ f.write("--- Init ---\n")
+ f.write(f"K shape {k.shape}, V shape {v.shape}, Q shape {q.shape}, dtype: {k.dtype}, {v.dtype}, {q.dtype}\n")
+ with torch.no_grad():
+ k_mean = k.float().mean().item()
+ k_max = k.float().max().item()
+ v_mean = v.float().mean().item()
+ v_max = v.float().max().item()
+ q_mean = q.float().mean().item()
+ q_max = q.float().max().item()
+ #some_q = q[first_dimension, second_dimension, third_dimension, fourth_dimension_start:fourth_dimension_end]
+ #some_k = k[first_dimension, second_dimension, third_dimension, fourth_dimension_start:fourth_dimension_end]
+ #some_v = v[first_dimension, second_dimension, third_dimension, fourth_dimension_start:fourth_dimension_end]
+ some_q = q[0, 0, -100:, 0]
+ some_k = k[0, 0, -100:, 0]
+ some_v = v[0, 0, -100:, 0]
+ f.write(f" Q stats: mean={q_mean:.6f}, max={q_max:.6f}\n")
+ f.write(f" Some Q stats: {some_q}\n")
+ f.write(f" K stats: mean={k_mean:.6f}, max={k_max:.6f}\n")
+ f.write(f" Some K stats: {some_k}\n")
+ f.write(f" V stats: mean={v_mean:.6f}, max={v_max:.6f}\n")
+ f.write(f" Some V stats: {some_v}\n")
+ if q.shape[-2] > 1041 and False:
+ last_q = q[:, :, -1041:, :]
+ with torch.no_grad():
+ last_q_mean = last_q.float().mean().item()
+ last_q_max = last_q.float().max().item()
+ some = last_q[first_dimension, second_dimension, third_dimension, fourth_dimension_start:fourth_dimension_end]
+ f.write(f" Last Q stats: mean={last_q_mean:.6f}, max={last_q_max:.6f}\n")
+ f.write(f" Some Last Q stats: {some}\n")
+
+ if k.shape[-2] > 1041 and False:
+ last_k = k[:, :, -1041:, :]
+ with torch.no_grad():
+ last_k_mean = last_k.float().mean().item()
+ last_k_max = last_k.float().max().item()
+ some = last_k[first_dimension, second_dimension, third_dimension, fourth_dimension_start:fourth_dimension_end]
+ f.write(f" Last K stats: mean={last_k_mean:.6f}, max={last_k_max:.6f}\n")
+ f.write(f" Some Last K stats: {some}\n")
+
+ if v.shape[-2] > 1041 and False:
+ last_v = v[:, :, -1041:, :]
+ with torch.no_grad():
+ last_v_mean = last_v.float().mean().item()
+ last_v_max = last_v.float().max().item()
+ some = last_v[first_dimension, second_dimension, third_dimension, fourth_dimension_start:fourth_dimension_end]
+ f.write(f" Last V stats: mean={last_v_mean:.6f}, max={last_v_max:.6f}\n")
+ f.write(f" Some Last V stats: {some}\n")
+ pos_k = pos
+ if use_cache:
+ k = k.unsqueeze(2)
+ v = v.unsqueeze(2)
+ if past_key_values is not None:
+ past_k, past_v = past_key_values
+ k = torch.cat([past_k, k], dim=2)
+ v = torch.cat([past_v, v], dim=2)
+
+ new_kv = (k, v)
+ a, b, c, d, e = k.shape
+ k = k.reshape(a, b, c*d, e)
+ v = v.reshape(a, b, c*d, e)
+ if pos_k is not None:
+ #print(pos_k.shape)
+ pos_k = pos_k.repeat(1, c, 1)
+ #print(pos_k.shape)
+ if False:
+ with open(log_file, "a") as f:
+ f.write("--- After past key values ---\n")
+ f.write(f"K shape {k.shape}, V shape {v.shape}, dtype: {k.dtype}, {v.dtype}\n")
+ with torch.no_grad():
+ k_mean = k.float().mean().item()
+ k_max = k.float().max().item()
+ v_mean = v.float().mean().item()
+ v_max = v.float().max().item()
+ some_k = k[0, 0, -100:, 0]
+ some_v = v[0, 0, -100:, 0]
+ f.write(f" K stats: mean={k_mean:.6f}, max={k_max:.6f}\n")
+ f.write(f" Some K stats: {some_k}\n")
+ f.write(f" V stats: mean={v_mean:.6f}, max={v_max:.6f}\n")
+ f.write(f" Some V stats: {some_v}\n")
+ if k.shape[-2] > 1041 and False:
+ last_k = k[:, :, -1041:, :]
+ with torch.no_grad():
+ last_k_mean = last_k.float().mean().item()
+ last_k_max = last_k.float().max().item()
+ some = last_k[first_dimension, second_dimension, third_dimension, fourth_dimension_start:fourth_dimension_end]
+ f.write(f" Last K stats: mean={last_k_mean:.6f}, max={last_k_max:.6f}\n")
+ f.write(f" Some Last K stats: {some}\n")
+ if v.shape[-2] > 1041 and False:
+ last_v = v[:, :, -1041:, :]
+ with torch.no_grad():
+ last_v_mean = last_v.float().mean().item()
+ last_v_max = last_v.float().max().item()
+ some = last_v[first_dimension, second_dimension, third_dimension, fourth_dimension_start:fourth_dimension_end]
+ f.write(f" Last V stats: mean={last_v_mean:.6f}, max={last_v_max:.6f}\n")
+ f.write(f" Some Last V stats: {some}\n")
+
+ q, k = self.q_norm(q), self.k_norm(k)
+
+ if False:
+ with open(log_file, "a") as f:
+ f.write("--- After norm ---\n")
+ f.write(f"K shape {k.shape}, Q shape {q.shape}, dtype: {k.dtype}, {q.dtype}\n")
+ with torch.no_grad():
+ k_mean = k.float().mean().item()
+ k_max = k.float().max().item()
+ q_mean = q.float().mean().item()
+ q_max = q.float().max().item()
+ some_q = q[0, 0, -100:, 0]
+ some_k = k[0, 0, -100:, 0]
+
+ f.write(f" Q stats: mean={q_mean:.6f}, max={q_max:.6f}\n")
+ f.write(f" Some Q stats: {some_q}\n")
+ f.write(f" K stats: mean={k_mean:.6f}, max={k_max:.6f}\n")
+ f.write(f" Some K stats: {some_k}\n")
+ if q.shape[-2] > 1041 and False:
+ last_q = q[:, :, -1041:, :]
+ with torch.no_grad():
+ last_q_mean = last_q.float().mean().item()
+ last_q_max = last_q.float().max().item()
+ some = last_q[first_dimension, second_dimension, third_dimension, fourth_dimension_start:fourth_dimension_end]
+ f.write(f" Last Q stats: mean={last_q_mean:.6f}, max={last_q_max:.6f}\n")
+ f.write(f" Some Last Q stats: {some}\n")
+
+ if k.shape[-2] > 1041 and False:
+ last_k = k[:, :, -1041:, :]
+ with torch.no_grad():
+ last_k_mean = last_k.float().mean().item()
+ last_k_max = last_k.float().max().item()
+ some = last_k[first_dimension, second_dimension, third_dimension, fourth_dimension_start:fourth_dimension_end]
+ f.write(f" Last K stats: mean={last_k_mean:.6f}, max={last_k_max:.6f}\n")
+ f.write(f" Some Last K stats: {some}\n")
+
+ if self.rope is not None:
+
+ q = self.rope(q, pos)
+ k = self.rope(k, pos_k)
+
+ if False:
+ with open(log_file, "a") as f:
+ f.write("--- After ROPE ---\n")
+ f.write(f"K shape {k.shape}, Q shape {q.shape}, dtype: {k.dtype}, {q.dtype}\n")
+ with torch.no_grad():
+ q_mean = q.float().mean().item()
+ q_max = q.float().max().item()
+ k_mean = k.float().mean().item()
+ k_max = k.float().max().item()
+ some_q = q[0, 0, -100:, 0]
+ some_k = k[0, 0, -100:, 0]
+ f.write(f" Q stats: mean={q_mean:.6f}, max={q_max:.6f}\n")
+ f.write(f" Some Q stats: {some_q}\n")
+ f.write(f" K stats: mean={k_mean:.6f}, max={k_max:.6f}\n")
+ f.write(f" Some K stats: {some_k}\n")
+ if q.shape[-2] > 1041 and False:
+ last_q = q[:, :, -1041:, :]
+ with torch.no_grad():
+ last_q_mean = last_q.float().mean().item()
+ last_q_max = last_q.float().max().item()
+ some = last_q[first_dimension, second_dimension, third_dimension, fourth_dimension_start:fourth_dimension_end]
+ f.write(f" Last Q stats: mean={last_q_mean:.6f}, max={last_q_max:.6f}\n")
+ f.write(f" Some Last Q stats: {some}\n")
+
+ if k.shape[-2] > 1041 and False:
+ last_k = k[:, :, -1041:, :]
+ with torch.no_grad():
+ last_k_mean = last_k.float().mean().item()
+ last_k_max = last_k.float().max().item()
+ some = last_k[first_dimension, second_dimension, third_dimension, fourth_dimension_start:fourth_dimension_end]
+ f.write(f" Last K stats: mean={last_k_mean:.6f}, max={last_k_max:.6f}\n")
+ f.write(f" Some Last K stats: {some}\n")
+
+
+ if self.fused_attn:
+ x = F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ attn_mask=attn_mask,
+ dropout_p=self.attn_drop.p if self.training else 0.0,
+ )
+
+ else:
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+
+ # Mask
+ if attn_mask is not None:
+ assert attn_mask.shape[-2:] == (N, N), f"Expected mask shape [..., {N}, {N}], got {attn_mask.shape}"
+ attn = attn + attn_mask
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = attn @ v
+
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+
+ if use_cache:
+ return x, new_kv
+ return x
+
+
+class MemEffAttention(Attention):
+ def forward(
+ self,
+ x: Tensor,
+ attn_bias=None,
+ pos=None,
+ past_key_values=None,
+ use_cache=False
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict]]:
+
+ assert pos is None
+ if use_cache or not XFORMERS_AVAILABLE:
+ if attn_bias is not None:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return super().forward(
+ x,
+ pos=pos,
+ attn_mask=attn_bias,
+ past_key_values=past_key_values,
+ use_cache=use_cache
+ )
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+
+ return x
diff --git a/vggt/layers/block.py b/vggt/layers/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..a70ada5771ebdf4ec735320171b3b75cd9fc911a
--- /dev/null
+++ b/vggt/layers/block.py
@@ -0,0 +1,292 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+import logging
+import os
+from typing import Callable, List, Any, Tuple, Dict, Union
+import warnings
+
+import torch
+from torch import nn, Tensor
+
+from .attention import Attention
+from .drop_path import DropPath
+from .layer_scale import LayerScale
+from .mlp import Mlp
+
+
+XFORMERS_AVAILABLE = False
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = True,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values=None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ qk_norm: bool = False,
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
+ rope=None,
+ ) -> None:
+ super().__init__()
+
+ self.norm1 = norm_layer(dim)
+
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ qk_norm=qk_norm,
+ fused_attn=fused_attn,
+ rope=rope,
+ )
+
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor, pos=None, attn_mask=None, past_key_values=None, use_cache=False) -> Union[Tensor, Tuple[Tensor, Dict]]:
+
+ def attn_residual_func(x: Tensor, pos=None, attn_mask=None, past_key_values=None, use_cache=False) -> Union[Tensor, Tuple[Tensor, Dict]]:
+ log_file = "baseline.log"
+ torch.set_printoptions(threshold=float('inf'))
+ torch.set_printoptions(precision=10)
+ if False:
+ with open(log_file, "a") as f:
+ f.write(f"X shape {x.shape}, dtype: {x.dtype}\n")
+ with torch.no_grad():
+ x_mean = x.float().mean().item()
+ x_max = x.float().max().item()
+ some_x = x[-1, :-100, 10]
+ f.write(f" X stats: mean={x_mean:.6f}, max={x_max:.6f}\n")
+ f.write(f" Some X stats: {some_x}\n")
+ if use_cache:
+ if attn_mask is not None:
+ output, new_kv = self.attn(self.norm1(x), pos=pos, attn_mask=attn_mask,
+ past_key_values=past_key_values, use_cache=True)
+ return self.ls1(output), new_kv
+ else:
+ output, new_kv = self.attn(self.norm1(x), pos=pos,
+ past_key_values=past_key_values, use_cache=True)
+ return self.ls1(output), new_kv
+ else:
+ if attn_mask is not None:
+ return self.ls1(self.attn(self.norm1(x), pos=pos, attn_mask=attn_mask))
+ else:
+ return self.ls1(self.attn(self.norm1(x), pos=pos))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if use_cache:
+ attn_output, new_kv = attn_residual_func(x, pos=pos, attn_mask=attn_mask,
+ past_key_values=past_key_values, use_cache=True)
+ x = x + attn_output
+ x = x + ffn_residual_func(x)
+ return x, new_kv
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ pos=pos,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x, pos=pos, attn_mask=attn_mask))
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
+ else:
+ x = x + attn_residual_func(x, pos=pos, attn_mask=attn_mask)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+def drop_add_residual_stochastic_depth(
+ x: Tensor,
+ residual_func: Callable[[Tensor], Tensor],
+ sample_drop_ratio: float = 0.0,
+ pos=None,
+) -> Tensor:
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ if pos is not None:
+ # if necessary, apply rope to the subset
+ pos = pos[brange]
+ residual = residual_func(x_subset, pos=pos)
+ else:
+ residual = residual_func(x_subset)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ else:
+ x_plus_residual = scaled_index_add(
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+ )
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[Tensor],
+ residual_func: Callable[[Tensor, Any], Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+ return outputs
+
+
+class NestedTensorBlock(Block):
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list):
+ if isinstance(x_or_x_list, Tensor):
+ return super().forward(x_or_x_list)
+ elif isinstance(x_or_x_list, list):
+ if not XFORMERS_AVAILABLE:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
diff --git a/vggt/layers/drop_path.py b/vggt/layers/drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5
--- /dev/null
+++ b/vggt/layers/drop_path.py
@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+from torch import nn
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ output = x * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
diff --git a/vggt/layers/layer_scale.py b/vggt/layers/layer_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..51df0d7ce61f2b41fa9e6369f52391dd7fe7d386
--- /dev/null
+++ b/vggt/layers/layer_scale.py
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+from torch import Tensor
+from torch import nn
+
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, Tensor] = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
diff --git a/vggt/layers/mlp.py b/vggt/layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbf9432aae9258612caeae910a7bde17999e328e
--- /dev/null
+++ b/vggt/layers/mlp.py
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
diff --git a/vggt/layers/patch_embed.py b/vggt/layers/patch_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b7c0804784a42cf80c0297d110dcc68cc85b339
--- /dev/null
+++ b/vggt/layers/patch_embed.py
@@ -0,0 +1,88 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+from torch import Tensor
+import torch.nn as nn
+
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (
+ image_HW[0] // patch_HW[0],
+ image_HW[1] // patch_HW[1],
+ )
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
diff --git a/vggt/layers/rope.py b/vggt/layers/rope.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d5d33304e55dbd05687bd86752a47a80e5f82df
--- /dev/null
+++ b/vggt/layers/rope.py
@@ -0,0 +1,188 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+
+# Implementation of 2D Rotary Position Embeddings (RoPE).
+
+# This module provides a clean implementation of 2D Rotary Position Embeddings,
+# which extends the original RoPE concept to handle 2D spatial positions.
+
+# Inspired by:
+# https://github.com/meta-llama/codellama/blob/main/llama/model.py
+# https://github.com/naver-ai/rope-vit
+
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Dict, Tuple
+
+
+class PositionGetter:
+ """Generates and caches 2D spatial positions for patches in a grid.
+
+ This class efficiently manages the generation of spatial coordinates for patches
+ in a 2D grid, caching results to avoid redundant computations.
+
+ Attributes:
+ position_cache: Dictionary storing precomputed position tensors for different
+ grid dimensions.
+ """
+
+ def __init__(self):
+ """Initializes the position generator with an empty cache."""
+ self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}
+
+ def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
+ """Generates spatial positions for a batch of patches.
+
+ Args:
+ batch_size: Number of samples in the batch.
+ height: Height of the grid in patches.
+ width: Width of the grid in patches.
+ device: Target device for the position tensor.
+
+ Returns:
+ Tensor of shape (batch_size, height*width, 2) containing y,x coordinates
+ for each position in the grid, repeated for each batch item.
+ """
+ if (height, width) not in self.position_cache:
+ y_coords = torch.arange(height, device=device)
+ x_coords = torch.arange(width, device=device)
+ positions = torch.cartesian_prod(y_coords, x_coords)
+ self.position_cache[height, width] = positions
+
+ cached_positions = self.position_cache[height, width]
+ return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
+
+
+class RotaryPositionEmbedding2D(nn.Module):
+ """2D Rotary Position Embedding implementation.
+
+ This module applies rotary position embeddings to input tokens based on their
+ 2D spatial positions. It handles the position-dependent rotation of features
+ separately for vertical and horizontal dimensions.
+
+ Args:
+ frequency: Base frequency for the position embeddings. Default: 100.0
+ scaling_factor: Scaling factor for frequency computation. Default: 1.0
+
+ Attributes:
+ base_frequency: Base frequency for computing position embeddings.
+ scaling_factor: Factor to scale the computed frequencies.
+ frequency_cache: Cache for storing precomputed frequency components.
+ """
+
+ def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
+ """Initializes the 2D RoPE module."""
+ super().__init__()
+ self.base_frequency = frequency
+ self.scaling_factor = scaling_factor
+ self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}
+
+ def _compute_frequency_components(
+ self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Computes frequency components for rotary embeddings.
+
+ Args:
+ dim: Feature dimension (must be even).
+ seq_len: Maximum sequence length.
+ device: Target device for computations.
+ dtype: Data type for the computed tensors.
+
+ Returns:
+ Tuple of (cosine, sine) tensors for frequency components.
+ """
+ cache_key = (dim, seq_len, device, dtype)
+ if cache_key not in self.frequency_cache:
+ # Compute frequency bands
+ exponents = torch.arange(0, dim, 2, device=device).float() / dim
+ inv_freq = 1.0 / (self.base_frequency**exponents)
+
+ # Generate position-dependent frequencies
+ positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
+ angles = torch.einsum("i,j->ij", positions, inv_freq)
+
+ # Compute and cache frequency components
+ angles = angles.to(dtype)
+ angles = torch.cat((angles, angles), dim=-1)
+ cos_components = angles.cos().to(dtype)
+ sin_components = angles.sin().to(dtype)
+ self.frequency_cache[cache_key] = (cos_components, sin_components)
+
+ return self.frequency_cache[cache_key]
+
+ @staticmethod
+ def _rotate_features(x: torch.Tensor) -> torch.Tensor:
+ """Performs feature rotation by splitting and recombining feature dimensions.
+
+ Args:
+ x: Input tensor to rotate.
+
+ Returns:
+ Rotated feature tensor.
+ """
+ feature_dim = x.shape[-1]
+ x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+ def _apply_1d_rope(
+ self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
+ ) -> torch.Tensor:
+ """Applies 1D rotary position embeddings along one dimension.
+
+ Args:
+ tokens: Input token features.
+ positions: Position indices.
+ cos_comp: Cosine components for rotation.
+ sin_comp: Sine components for rotation.
+
+ Returns:
+ Tokens with applied rotary position embeddings.
+ """
+ # Embed positions with frequency components
+ cos = F.embedding(positions, cos_comp)[:, None, :, :]
+ sin = F.embedding(positions, sin_comp)[:, None, :, :]
+
+ # Apply rotation
+ return (tokens * cos) + (self._rotate_features(tokens) * sin)
+
+ def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
+ """Applies 2D rotary position embeddings to input tokens.
+
+ Args:
+ tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
+ The feature dimension (dim) must be divisible by 4.
+ positions: Position tensor of shape (batch_size, n_tokens, 2) containing
+ the y and x coordinates for each token.
+
+ Returns:
+ Tensor of same shape as input with applied 2D rotary position embeddings.
+
+ Raises:
+ AssertionError: If input dimensions are invalid or positions are malformed.
+ """
+ # Validate inputs
+ assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
+ assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"
+
+ # Compute feature dimension for each spatial direction
+ feature_dim = tokens.size(-1) // 2
+
+ # Get frequency components
+ max_position = int(positions.max()) + 1
+ cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)
+
+ # Split features for vertical and horizontal processing
+ vertical_features, horizontal_features = tokens.chunk(2, dim=-1)
+
+ # Apply RoPE separately for each dimension
+ vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
+ horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)
+
+ # Combine processed features
+ return torch.cat((vertical_features, horizontal_features), dim=-1)
diff --git a/vggt/layers/swiglu_ffn.py b/vggt/layers/swiglu_ffn.py
new file mode 100644
index 0000000000000000000000000000000000000000..54fe8e90b7bedf6fbdbf09c6215844e3cc63f857
--- /dev/null
+++ b/vggt/layers/swiglu_ffn.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import os
+from typing import Callable, Optional
+import warnings
+
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x12 = self.w12(x)
+ x1, x2 = x12.chunk(2, dim=-1)
+ hidden = F.silu(x1) * x2
+ return self.w3(hidden)
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+# try:
+# if XFORMERS_ENABLED:
+# from xformers.ops import SwiGLU
+
+# XFORMERS_AVAILABLE = True
+# warnings.warn("xFormers is available (SwiGLU)")
+# else:
+# warnings.warn("xFormers is disabled (SwiGLU)")
+# raise ImportError
+# except ImportError:
+SwiGLU = SwiGLUFFN
+XFORMERS_AVAILABLE = False
+
+# warnings.warn("xFormers is not available (SwiGLU)")
+
+
+class SwiGLUFFNFused(SwiGLU):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+ super().__init__(
+ in_features=in_features,
+ hidden_features=hidden_features,
+ out_features=out_features,
+ bias=bias,
+ )
diff --git a/vggt/layers/vision_transformer.py b/vggt/layers/vision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..120cbe6c26650d212e50aefc497669abdc937467
--- /dev/null
+++ b/vggt/layers/vision_transformer.py
@@ -0,0 +1,407 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable
+
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+from torch.nn.init import trunc_normal_
+from . import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
+
+logger = logging.getLogger("dinov2")
+
+
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = ".".join((name, child_name)) if name else child_name
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+
+class BlockChunk(nn.ModuleList):
+ def forward(self, x):
+ for b in self:
+ x = b(x)
+ return x
+
+
+class DinoVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ init_values=None, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=Block,
+ ffn_layer="mlp",
+ block_chunks=1,
+ num_register_tokens=0,
+ interpolate_antialias=False,
+ interpolate_offset=0.1,
+ qk_norm=False,
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ weight_init (str): weight init scheme
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
+ """
+ super().__init__()
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ # tricky but makes it work
+ self.use_checkpoint = False
+ #
+
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+ self.num_register_tokens = num_register_tokens
+ self.interpolate_antialias = interpolate_antialias
+ self.interpolate_offset = interpolate_offset
+
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+ assert num_register_tokens >= 0
+ self.register_tokens = (
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
+ )
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ logger.info("using MLP layer as FFN")
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ logger.info("using SwiGLU layer as FFN")
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+ logger.info("using Identity layer as FFN")
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ qk_norm=qk_norm,
+ )
+ for i in range(depth)
+ ]
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+ self.init_weights()
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+ nn.init.normal_(self.cls_token, std=1e-6)
+ if self.register_tokens is not None:
+ nn.init.normal_(self.register_tokens, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
+ def interpolate_pos_encoding(self, x, w, h):
+ previous_dtype = x.dtype
+ npatch = x.shape[1] - 1
+ N = self.pos_embed.shape[1] - 1
+ if npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ class_pos_embed = pos_embed[:, 0]
+ patch_pos_embed = pos_embed[:, 1:]
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+ M = int(math.sqrt(N)) # Recover the number of patches in each dimension
+ assert N == M * M
+ kwargs = {}
+ if self.interpolate_offset:
+ # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
+ # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
+ sx = float(w0 + self.interpolate_offset) / M
+ sy = float(h0 + self.interpolate_offset) / M
+ kwargs["scale_factor"] = (sx, sy)
+ else:
+ # Simply specify an output size instead of a scale factor
+ kwargs["size"] = (w0, h0)
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
+ mode="bicubic",
+ antialias=self.interpolate_antialias,
+ **kwargs,
+ )
+ assert (w0, h0) == patch_pos_embed.shape[-2:]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+ def prepare_tokens_with_masks(self, x, masks=None):
+ B, nc, w, h = x.shape
+ x = self.patch_embed(x)
+ if masks is not None:
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = x + self.interpolate_pos_encoding(x, w, h)
+
+ if self.register_tokens is not None:
+ x = torch.cat(
+ (
+ x[:, :1],
+ self.register_tokens.expand(x.shape[0], -1, -1),
+ x[:, 1:],
+ ),
+ dim=1,
+ )
+
+ return x
+
+ def forward_features_list(self, x_list, masks_list):
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
+ else:
+ x = blk(x)
+
+ all_x = x
+ output = []
+ for x, masks in zip(all_x, masks_list):
+ x_norm = self.norm(x)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ def forward_features(self, x, masks=None):
+ if isinstance(x, list):
+ return self.forward_features_list(x, masks)
+
+ x = self.prepare_tokens_with_masks(x, masks)
+
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
+ else:
+ x = blk(x)
+
+ x_norm = self.norm(x)
+ return {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ # If n is an int, take the n last blocks. If it's a list, take them
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def _get_intermediate_layers_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+ # If n is an int, take the n last blocks. If it's a list, take them
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for block_chunk in self.blocks:
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ norm=True,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x, n)
+ else:
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, 0] for out in outputs]
+ outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
+ if reshape:
+ B, _, w, h = x.shape
+ outputs = [
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+ for out in outputs
+ ]
+ if return_class_token:
+ return tuple(zip(outputs, class_tokens))
+ return tuple(outputs)
+
+ def forward(self, *args, is_training=True, **kwargs):
+ ret = self.forward_features(*args, **kwargs)
+ if is_training:
+ return ret
+ else:
+ return self.head(ret["x_norm_clstoken"])
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+
+def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
diff --git a/vggt/models/__pycache__/aggregator.cpython-310.pyc b/vggt/models/__pycache__/aggregator.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..94aa0c3d492761931179b5407f040badc2ff59f2
Binary files /dev/null and b/vggt/models/__pycache__/aggregator.cpython-310.pyc differ
diff --git a/vggt/models/__pycache__/aggregator.cpython-311.pyc b/vggt/models/__pycache__/aggregator.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..64d568ae4764997558c93f23b5a3bb622bdc91ef
Binary files /dev/null and b/vggt/models/__pycache__/aggregator.cpython-311.pyc differ
diff --git a/vggt/models/__pycache__/aggregator.cpython-312.pyc b/vggt/models/__pycache__/aggregator.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2b16906bd7e7872184ede1be52c7bc45bce9a10f
Binary files /dev/null and b/vggt/models/__pycache__/aggregator.cpython-312.pyc differ
diff --git a/vggt/models/__pycache__/vggt.cpython-310.pyc b/vggt/models/__pycache__/vggt.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1e494953a01c63a708f2a3f792c87818dc53af3e
Binary files /dev/null and b/vggt/models/__pycache__/vggt.cpython-310.pyc differ
diff --git a/vggt/models/__pycache__/vggt.cpython-311.pyc b/vggt/models/__pycache__/vggt.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cc9c9da41aa8541d55bce3695c415f5a76edde0c
Binary files /dev/null and b/vggt/models/__pycache__/vggt.cpython-311.pyc differ
diff --git a/vggt/models/__pycache__/vggt.cpython-312.pyc b/vggt/models/__pycache__/vggt.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8805daefae1703af26efb3f0d33758ae5106ede6
Binary files /dev/null and b/vggt/models/__pycache__/vggt.cpython-312.pyc differ
diff --git a/vggt/models/aggregator.py b/vggt/models/aggregator.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5e8b4875630caa73f4fc88302c5f1e5e340a2d0
--- /dev/null
+++ b/vggt/models/aggregator.py
@@ -0,0 +1,566 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Optional, Tuple, Union, List, Dict, Any
+
+from vggt.layers import PatchEmbed
+from vggt.layers.block import Block
+from vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
+from vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
+
+logger = logging.getLogger(__name__)
+
+_RESNET_MEAN = [0.485, 0.456, 0.406]
+_RESNET_STD = [0.229, 0.224, 0.225]
+
+
+class Aggregator(nn.Module):
+ """
+ The Aggregator applies alternating-attention over input frames,
+ as described in VGGT: Visual Geometry Grounded Transformer.
+
+
+ Args:
+ img_size (int): Image size in pixels.
+ patch_size (int): Size of each patch for PatchEmbed.
+ embed_dim (int): Dimension of the token embeddings.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
+ num_register_tokens (int): Number of register tokens.
+ block_fn (nn.Module): The block type used for attention (Block by default).
+ qkv_bias (bool): Whether to include bias in QKV projections.
+ proj_bias (bool): Whether to include bias in the output projection.
+ ffn_bias (bool): Whether to include bias in MLP layers.
+ patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg".
+ aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"].
+ aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1.
+ qk_norm (bool): Whether to apply QK normalization.
+ rope_freq (int): Base frequency for rotary embedding. -1 to disable.
+ init_values (float): Init scale for layer scale.
+ """
+
+ def __init__(
+ self,
+ img_size=518,
+ patch_size=14,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4.0,
+ num_register_tokens=4,
+ block_fn=Block,
+ qkv_bias=True,
+ proj_bias=True,
+ ffn_bias=True,
+ patch_embed="dinov2_vitl14_reg",
+ aa_order=["frame", "global"],
+ aa_block_size=1,
+ qk_norm=True,
+ rope_freq=100,
+ init_values=0.01,
+ use_causal_global=True,
+ ):
+ super().__init__()
+
+ self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim)
+
+ # Initialize rotary position embedding if frequency > 0
+ self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
+ self.position_getter = PositionGetter() if self.rope is not None else None
+
+ self.frame_blocks = nn.ModuleList(
+ [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ init_values=init_values,
+ qk_norm=qk_norm,
+ rope=self.rope,
+ )
+ for _ in range(depth)
+ ]
+ )
+
+ self.global_blocks = nn.ModuleList(
+ [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ init_values=init_values,
+ qk_norm=qk_norm,
+ rope=self.rope,
+ )
+ for _ in range(depth)
+ ]
+ )
+
+ self.depth = depth
+ self.aa_order = aa_order
+ self.patch_size = patch_size
+ self.aa_block_size = aa_block_size
+
+ # Validate that depth is divisible by aa_block_size
+ if self.depth % self.aa_block_size != 0:
+ raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")
+
+ self.aa_block_num = self.depth // self.aa_block_size
+
+ # Note: We have two camera tokens, one for the first frame and one for the rest
+ # The same applies for register tokens
+ self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
+ self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim))
+
+ # The patch tokens start after the camera and register tokens
+ self.patch_start_idx = 1 + num_register_tokens
+
+ # Initialize parameters with small values
+ nn.init.normal_(self.camera_token, std=1e-6)
+ nn.init.normal_(self.register_token, std=1e-6)
+
+ # Register normalization constants as buffers
+ for name, value in (
+ ("_resnet_mean", _RESNET_MEAN),
+ ("_resnet_std", _RESNET_STD),
+ ):
+ self.register_buffer(
+ name,
+ torch.FloatTensor(value).reshape(1, 1, 3, 1, 1),
+ persistent=False,
+ )
+
+ self.use_causal_global = use_causal_global
+
+ def __build_patch_embed__(
+ self,
+ patch_embed,
+ img_size,
+ patch_size,
+ num_register_tokens,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ block_chunks=0,
+ init_values=1.0,
+ embed_dim=1024,
+ ):
+ """
+ Build the patch embed layer. If 'conv', we use a
+ simple PatchEmbed conv layer. Otherwise, we use a vision transformer.
+ """
+
+ if "conv" in patch_embed:
+ self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim)
+ else:
+ vit_models = {
+ "dinov2_vitl14_reg": vit_large,
+ "dinov2_vitb14_reg": vit_base,
+ "dinov2_vits14_reg": vit_small,
+ "dinov2_vitg2_reg": vit_giant2,
+ }
+
+ self.patch_embed = vit_models[patch_embed](
+ img_size=img_size,
+ patch_size=patch_size,
+ num_register_tokens=num_register_tokens,
+ interpolate_antialias=interpolate_antialias,
+ interpolate_offset=interpolate_offset,
+ block_chunks=block_chunks,
+ init_values=init_values,
+ )
+
+ # Disable gradient updates for mask token
+ if hasattr(self.patch_embed, "mask_token"):
+ self.patch_embed.mask_token.requires_grad_(False)
+
+ def forward(
+ self,
+ images: torch.Tensor,
+ past_key_values=None,
+ use_cache=False,
+ past_frame_idx=0
+ ) -> Union[Tuple[List[torch.Tensor], int], Tuple[List[torch.Tensor], int, Dict]]:
+ """
+ Args:
+ images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
+ B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
+
+ Returns:
+ (list[torch.Tensor], int):
+ The list of outputs from the attention blocks,
+ and the patch_start_idx indicating where patch tokens begin.
+ """
+ B, S, C_in, H, W = images.shape
+ if use_cache and past_key_values[0] is not None:
+ _, _, S_true, _, _ = past_key_values[0][0].shape
+ S_true += 1
+ else:
+ S_true = S
+
+ if False:
+ import os
+ from datetime import datetime
+ first_dimension = 0
+ second_dimension = 1
+ third_dimension = 0
+ fourth_dimension_start = 10
+ fourth_dimension_end = 20
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ log_file = f"baseline.log"
+
+ if use_cache and S > 1:
+ print(f"Use KV cache expects S=1, got S={S}")
+
+ if False:
+ with open(log_file, "a") as f:
+ if use_cache and past_key_values is not None:
+ f.write(f"=== Past KV cache initial state ===\n")
+ for i, kv in enumerate(past_key_values):
+ if kv is not None:
+ k, v = kv
+ f.write(f"Block {i}: K shape {k.shape}, V shape {v.shape}, dtype: {k.dtype}, {v.dtype}\n")
+ with torch.no_grad():
+ k_mean = k.float().mean().item()
+ k_max = k.float().max().item()
+ v_mean = v.float().mean().item()
+ v_max = v.float().max().item()
+ f.write(f" K stats: mean={k_mean:.6f}, max={k_max:.6f}\n")
+ f.write(f" V stats: mean={v_mean:.6f}, max={v_max:.6f}\n")
+
+ if k.shape[-2] >= 1041 and False:
+ last_k = k[:, :, -1041:, :]
+ with torch.no_grad():
+ last_k_mean = last_k.float().mean().item()
+ last_k_max = last_k.float().max().item()
+ some = last_k[first_dimension, second_dimension, third_dimension, fourth_dimension_start:fourth_dimension_end]
+ f.write(f" Last K stats: mean={last_k_mean:.6f}, max={last_k_max:.6f}\n")
+ f.write(f" Some Last K stats: {some}\n")
+
+ if v.shape[-2] >= 1041 and False:
+ last_v = v[:, :, -1041:, :]
+ with torch.no_grad():
+ last_v_mean = last_v.float().mean().item()
+ last_v_max = last_v.float().max().item()
+ some = last_v[first_dimension, second_dimension, third_dimension, fourth_dimension_start:fourth_dimension_end]
+ f.write(f" Last V stats: mean={last_v_mean:.6f}, max={last_v_max:.6f}\n")
+ f.write(f" Some Last V stats: {some}\n")
+ else:
+ f.write(f"Block {i}: None\n")
+
+ # Normalize images and reshape for patch embed
+ images = (images - self._resnet_mean.to(images.device)) / self._resnet_std.to(images.device)
+
+ # Reshape to [B*S, C, H, W] for patch embedding
+ images = images.reshape(B * S, C_in, H, W)
+ patch_tokens = self.patch_embed(images)
+
+
+
+ if isinstance(patch_tokens, dict):
+ patch_tokens = patch_tokens["x_norm_patchtokens"]
+
+ if False:
+ x = patch_tokens
+ with open(log_file, "a") as f:
+ f.write(f"=== Patch tokens initial state ===\n")
+ f.write(f"X shape {x.shape}, dtype: {x.dtype}\n")
+ with torch.no_grad():
+ x_mean = x.float().mean().item()
+ x_max = x.float().max().item()
+ some_x = x[-1, :-100, 10]
+ f.write(f" X stats: mean={x_mean:.6f}, max={x_max:.6f}\n")
+ f.write(f" Some X stats: {some_x}\n")
+ f.write(f"=== Patch tokens after embedding ===\n")
+
+ _, P, C = patch_tokens.shape
+
+ if use_cache:
+ # Expand camera and register tokens to match batch size and sequence length
+ camera_token_full = slice_expand_and_flatten(self.camera_token, B, S_true)
+ camera_token = camera_token_full[-1:, :, :]
+
+ register_token_full = slice_expand_and_flatten(self.register_token, B, S_true)
+ register_token = register_token_full[-1:, :, :]
+ else:
+ camera_token = slice_expand_and_flatten(self.camera_token, B, S)
+ register_token = slice_expand_and_flatten(self.register_token, B, S)
+ if False:
+ with open(log_file, "a") as f:
+ f.write(f"=== Camera tokens initial state ===\n")
+ f.write(f"Camera token shape {camera_token.shape}, dtype: {camera_token.dtype}\n")
+ f.write(f"{camera_token}\n")
+ if False:
+ with open(log_file, "a") as f:
+ f.write(f"=== Register tokens initial state ===\n")
+ f.write(f"Register token shape {register_token.shape}, dtype: {register_token.dtype}\n")
+ f.write(f"{register_token}\n")
+
+ # Concatenate special tokens with patch tokens
+ tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1)
+
+ pos = None
+ if self.rope is not None:
+ pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device)
+
+ if self.patch_start_idx > 0:
+ # do not use position embedding for special tokens (camera and register tokens)
+ # so set pos to 0 for the special tokens
+ pos = pos + 1
+ pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype)
+ pos = torch.cat([pos_special, pos], dim=1)
+
+ # update P because we added special tokens
+ _, P, C = tokens.shape
+
+ # 在处理区块前添加日志
+ frame_idx = 0
+ global_idx = 0
+ output_list = []
+
+ for block_num in range(self.aa_block_num):
+ if False:
+ with open(log_file, "a") as f:
+ f.write(f"=== Processing AA block {block_num}/{self.aa_block_num} ===\n")
+ for attn_type in self.aa_order:
+ if attn_type == "frame":
+ tokens, frame_idx, frame_intermediates = self._process_frame_attention(
+ tokens, B, S, P, C, frame_idx, pos=pos
+ )
+ if False:
+ with open(log_file, "a") as f:
+ f.write(f"Frame attention completed: frame_idx = {frame_idx}\n")
+ elif attn_type == "global":
+ if use_cache:
+ if False:
+ with open(log_file, "a") as f:
+ f.write(f"Global attention with KV cache: global_idx = {global_idx}\n")
+ if past_key_values[global_idx] is not None:
+ k, v = past_key_values[global_idx]
+ if False:
+ with open(log_file, "a") as f:
+ f.write(f" Using cached KV at idx {global_idx}: K shape {k.shape}, V shape {v.shape}, dtype: {k.dtype}, {v.dtype}\n")
+
+ tokens, global_idx, global_intermediates, new_kv = self._process_global_attention(
+ tokens, B, S, P, C, global_idx, pos=pos,
+ past_key_values_block=past_key_values[global_idx] if past_key_values[global_idx] is not None else None,
+ use_cache=True,
+ past_frame_idx=past_frame_idx
+ )
+ if False:
+ with open(log_file, "a") as f:
+ # 打印更新后的 KV 缓存
+ if new_kv is not None:
+ k, v = new_kv
+ f.write(f" New KV cache: K shape {k.shape}, V shape {v.shape}, dtype: {k.dtype}, {v.dtype}\n")
+ f.write(f" Updating block at idx {global_idx-1}\n")
+ with torch.no_grad():
+ k_mean = k.float().mean().item()
+ k_max = k.float().max().item()
+ v_mean = v.float().mean().item()
+ v_max = v.float().max().item()
+ f.write(f" K stats: mean={k_mean:.6f}, max={k_max:.6f}\n")
+ f.write(f" V stats: mean={v_mean:.6f}, max={v_max:.6f}\n")
+
+ if k.shape[-2] >= 1041:
+ last_k = k[:, :, -1041:, :]
+ with torch.no_grad():
+ last_k_mean = last_k.float().mean().item()
+ last_k_max = last_k.float().max().item()
+ some = last_k[first_dimension, second_dimension, third_dimension, fourth_dimension_start:fourth_dimension_end]
+ f.write(f" Last K stats: mean={last_k_mean:.6f}, max={last_k_max:.6f}\n")
+ #f.write(f" Some Last K stats: {some}\n")
+
+ if v.shape[-2] >= 1041:
+ last_v = v[:, :, -1041:, :]
+ with torch.no_grad():
+ last_v_mean = last_v.float().mean().item()
+ last_v_max = last_v.float().max().item()
+ some = last_v[first_dimension, second_dimension, third_dimension, fourth_dimension_start:fourth_dimension_end]
+ f.write(f" Last V stats: mean={last_v_mean:.6f}, max={last_v_max:.6f}\n")
+ #f.write(f" Some Last V stats: {some}\n")
+
+ past_key_values[global_idx - 1] = new_kv
+ else:
+ tokens, global_idx, global_intermediates = self._process_global_attention(
+ tokens, B, S, P, C, global_idx, pos=pos
+ )
+ if False:
+ with open(log_file, "a") as f:
+ f.write(f"Global attention without KV cache: global_idx = {global_idx}\n")
+ else:
+ raise ValueError(f"Unknown attention type: {attn_type}")
+
+ for i in range(len(frame_intermediates)):
+ # concat frame and global intermediates, [B x S x P x 2C]
+ concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
+ output_list.append(concat_inter)
+
+ del concat_inter
+ del frame_intermediates
+ del global_intermediates
+
+ if False:
+ with open(log_file, "a") as f:
+ # 在返回前添加最终状态日志
+ if use_cache:
+ f.write(f"=== Final KV cache state ===\n")
+ for i, kv in enumerate(past_key_values):
+ if kv is not None:
+ k, v = kv
+ f.write(f"Block {i}: K shape {k.shape}, V shape {v.shape}, dtype: {k.dtype}, {v.dtype}\n")
+ with torch.no_grad():
+ k_mean = k.float().mean().item()
+ k_max = k.float().max().item()
+ v_mean = v.float().mean().item()
+ v_max = v.float().max().item()
+ f.write(f" K stats: mean={k_mean:.6f}, max={k_max:.6f}\n")
+ f.write(f" V stats: mean={v_mean:.6f}, max={v_max:.6f}\n")
+
+ if k.shape[-2] >= 1041 and False:
+ last_k = k[:, :, -1041:, :]
+ with torch.no_grad():
+ last_k_mean = last_k.float().mean().item()
+ last_k_max = last_k.float().max().item()
+ some = last_k[first_dimension, second_dimension, third_dimension, fourth_dimension_start:fourth_dimension_end]
+ f.write(f" Last K stats: mean={last_k_mean:.6f}, max={last_k_max:.6f}\n")
+ f.write(f" Some Last K stats: {some}\n")
+
+ if v.shape[-2] >= 1041 and False:
+ last_v = v[:, :, -1041:, :]
+ with torch.no_grad():
+ last_v_mean = last_v.float().mean().item()
+ last_v_max = last_v.float().max().item()
+ some = last_v[first_dimension, second_dimension, third_dimension, fourth_dimension_start:fourth_dimension_end]
+ f.write(f" Last V stats: mean={last_v_mean:.6f}, max={last_v_max:.6f}\n")
+ f.write(f" Some Last V stats: {some}\n")
+ else:
+ f.write(f"Block {i}: None\n")
+
+ if use_cache:
+ return output_list, self.patch_start_idx, past_key_values
+ return output_list, self.patch_start_idx
+
+
+ def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
+ """
+ Process frame attention blocks. We keep tokens in shape (B*S, P, C).
+ """
+ # If needed, reshape tokens or positions:
+ if tokens.shape != (B * S, P, C):
+ tokens = tokens.reshape(B, S, P, C).reshape(B * S, P, C)
+
+ if pos is not None and pos.shape != (B * S, P, 2):
+ pos = pos.reshape(B, S, P, 2).reshape(B * S, P, 2)
+
+ intermediates = []
+
+ # by default, self.aa_block_size=1, which processes one block at a time
+ for _ in range(self.aa_block_size):
+ tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
+ frame_idx += 1
+ intermediates.append(tokens.reshape(B, S, P, C))
+
+ return tokens, frame_idx, intermediates
+
+
+ def _process_global_attention(
+ self,
+ tokens,
+ B,
+ S,
+ P,
+ C,
+ global_idx,
+ pos=None,
+ past_key_values_block=None,
+ use_cache=False,
+ past_frame_idx=0
+ ) -> Union[Tuple[torch.Tensor, int, List[torch.Tensor]], Tuple[torch.Tensor, int, List[torch.Tensor], List]]:
+ """
+ Process global attention blocks. We keep tokens in shape (B, S*P, C).
+ """
+ #torch.cuda.synchronize()
+ #start_mem = torch.cuda.memory_allocated() / 1024 / 1024 # MB
+ #print(f"Before _process_global_attention: {start_mem:.2f} MB")
+
+ if tokens.shape != (B, S * P, C):
+ tokens = tokens.reshape(B, S, P, C).reshape(B, S * P, C)
+
+ if pos is not None and pos.shape != (B, S * P, 2):
+ pos = pos.reshape(B, S, P, 2).reshape(B, S * P, 2)
+
+ intermediates = []
+
+ for _ in range(self.aa_block_size):
+
+ if self.use_causal_global and not use_cache:
+ L = S * P
+ frame_ids = torch.arange(L, device=tokens.device) // P # [0,0,...,1,1,...,S-1]
+ future_frame = frame_ids.unsqueeze(1) < frame_ids.unsqueeze(0)
+ attn_mask = future_frame.to(tokens.dtype) * torch.finfo(tokens.dtype).min
+ else:
+ attn_mask = None
+
+ if use_cache:
+ tokens, block_kv = self.global_blocks[global_idx](
+ tokens,
+ pos=pos,
+ attn_mask=attn_mask,
+ past_key_values=past_key_values_block,
+ use_cache=True
+ )
+ else:
+ tokens = self.global_blocks[global_idx](tokens, pos=pos, attn_mask=attn_mask)
+
+ global_idx += 1
+ intermediates.append(tokens.reshape(B, S, P, C))
+
+ # if self.use_causal_global:
+ # del attn_mask
+
+ if use_cache:
+ return tokens, global_idx, intermediates, block_kv
+ else:
+ return tokens, global_idx, intermediates
+
+
+
+
+def slice_expand_and_flatten(token_tensor, B, S):
+ """
+ Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing:
+ 1) Uses the first position (index=0) for the first frame only
+ 2) Uses the second position (index=1) for all remaining frames (S-1 frames)
+ 3) Expands both to match batch size B
+ 4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token
+ followed by (S-1) second-position tokens
+ 5) Flattens to (B*S, X, C) for processing
+
+ Returns:
+ torch.Tensor: Processed tokens with shape (B*S, X, C)
+ """
+
+ # Slice out the "query" tokens => shape (1, 1, ...)
+ query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:])
+ # Slice out the "other" tokens => shape (1, S-1, ...)
+ others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:])
+ # Concatenate => shape (B, S, ...)
+ combined = torch.cat([query, others], dim=1)
+
+ # Finally flatten => shape (B*S, ...)
+ combined = combined.reshape(B * S, *combined.shape[2:])
+ return combined
diff --git a/vggt/models/vggt.py b/vggt/models/vggt.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7bfa31fc87142c7f0f7720cde9cbec6db0aedd8
--- /dev/null
+++ b/vggt/models/vggt.py
@@ -0,0 +1,99 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from huggingface_hub import PyTorchModelHubMixin # used for model hub
+
+from vggt.models.aggregator import Aggregator
+from vggt.heads.camera_head import CameraHead
+from vggt.heads.dpt_head import DPTHead
+from vggt.heads.track_head import TrackHead
+
+
+class VGGT(nn.Module, PyTorchModelHubMixin):
+ def __init__(self, img_size=518, patch_size=14, embed_dim=1024, use_causal_global=True, use_distil=False):
+ super().__init__()
+
+ self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim, use_causal_global=use_causal_global)
+ self.camera_head = CameraHead(dim_in=2 * embed_dim)
+ self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1")
+ self.depth_head = DPTHead(dim_in=2 * embed_dim, output_dim=2, activation="exp", conf_activation="expp1")
+ self.track_head = TrackHead(dim_in=2 * embed_dim, patch_size=patch_size)
+ self.use_causal_global = use_causal_global
+ self.use_distil = use_distil
+
+ def forward(
+ self,
+ images: torch.Tensor,
+ query_points: torch.Tensor = None,
+ ):
+ """
+ Forward pass of the VGGT model.
+
+ Args:
+ images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1].
+ B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
+ query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates.
+ Shape: [N, 2] or [B, N, 2], where N is the number of query points.
+ Default: None
+
+ Returns:
+ dict: A dictionary containing the following predictions:
+ - pose_enc (torch.Tensor): Camera pose encoding with shape [B, S, 9] (from the last iteration)
+ - depth (torch.Tensor): Predicted depth maps with shape [B, S, H, W, 1]
+ - depth_conf (torch.Tensor): Confidence scores for depth predictions with shape [B, S, H, W]
+ - world_points (torch.Tensor): 3D world coordinates for each pixel with shape [B, S, H, W, 3]
+ - world_points_conf (torch.Tensor): Confidence scores for world points with shape [B, S, H, W]
+ - images (torch.Tensor): Original input images, preserved for visualization
+
+ If query_points is provided, also includes:
+ - track (torch.Tensor): Point tracks with shape [B, S, N, 2] (from the last iteration), in pixel coordinates
+ - vis (torch.Tensor): Visibility scores for tracked points with shape [B, S, N]
+ - conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N]
+ """
+
+ # If without batch dimension, add it
+ if len(images.shape) == 4:
+ images = images.unsqueeze(0)
+ if query_points is not None and len(query_points.shape) == 2:
+ query_points = query_points.unsqueeze(0)
+
+ aggregated_tokens_list, patch_start_idx = self.aggregator(images)
+
+ predictions = {}
+
+ with torch.cuda.amp.autocast(enabled=False):
+ if self.camera_head is not None:
+ pose_enc_list = self.camera_head(aggregated_tokens_list)
+ predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
+
+
+ if self.depth_head is not None:
+ depth, depth_conf = self.depth_head(
+ aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
+ )
+ predictions["depth"] = depth
+ predictions["depth_conf"] = depth_conf
+
+ if self.point_head is not None:
+ pts3d, pts3d_conf = self.point_head(
+ aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
+ )
+ predictions["world_points"] = pts3d
+ predictions["world_points_conf"] = pts3d_conf
+
+ if self.track_head is not None and query_points is not None:
+ track_list, vis, conf = self.track_head(
+ aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx, query_points=query_points
+ )
+ predictions["track"] = track_list[-1] # track of the last iteration
+ predictions["vis"] = vis
+ predictions["conf"] = conf
+
+ predictions["images"] = images
+
+ return predictions
diff --git a/vggt/utils/__pycache__/geometry.cpython-310.pyc b/vggt/utils/__pycache__/geometry.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d2d93ce51882e026cb62645e5745e71d52096275
Binary files /dev/null and b/vggt/utils/__pycache__/geometry.cpython-310.pyc differ
diff --git a/vggt/utils/__pycache__/geometry.cpython-311.pyc b/vggt/utils/__pycache__/geometry.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eec2dadca23648a1fa8f31d4e34f3654a12e4b2b
Binary files /dev/null and b/vggt/utils/__pycache__/geometry.cpython-311.pyc differ
diff --git a/vggt/utils/__pycache__/geometry.cpython-312.pyc b/vggt/utils/__pycache__/geometry.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9e1ef14697f726259cc923ae2552e8551bde980a
Binary files /dev/null and b/vggt/utils/__pycache__/geometry.cpython-312.pyc differ
diff --git a/vggt/utils/__pycache__/load_fn.cpython-310.pyc b/vggt/utils/__pycache__/load_fn.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ea3ed152a1a219081a9dde2133aed9081943abe3
Binary files /dev/null and b/vggt/utils/__pycache__/load_fn.cpython-310.pyc differ
diff --git a/vggt/utils/__pycache__/load_fn.cpython-311.pyc b/vggt/utils/__pycache__/load_fn.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..15dec3e01cba36ca4b56e174a502b734cd1b1fd5
Binary files /dev/null and b/vggt/utils/__pycache__/load_fn.cpython-311.pyc differ
diff --git a/vggt/utils/__pycache__/load_fn.cpython-312.pyc b/vggt/utils/__pycache__/load_fn.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d10598b69f2fef4c7a9329c388e141932a02b342
Binary files /dev/null and b/vggt/utils/__pycache__/load_fn.cpython-312.pyc differ
diff --git a/vggt/utils/__pycache__/pose_enc.cpython-310.pyc b/vggt/utils/__pycache__/pose_enc.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aa1b72512089483ce434a727f9ff7d69ab42faa1
Binary files /dev/null and b/vggt/utils/__pycache__/pose_enc.cpython-310.pyc differ
diff --git a/vggt/utils/__pycache__/pose_enc.cpython-311.pyc b/vggt/utils/__pycache__/pose_enc.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4d7b414b544ecab11493014ff77248f3691994b6
Binary files /dev/null and b/vggt/utils/__pycache__/pose_enc.cpython-311.pyc differ
diff --git a/vggt/utils/__pycache__/pose_enc.cpython-312.pyc b/vggt/utils/__pycache__/pose_enc.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8f56b1abd75896ad462781f44676bb2f21b9aee4
Binary files /dev/null and b/vggt/utils/__pycache__/pose_enc.cpython-312.pyc differ
diff --git a/vggt/utils/__pycache__/rotation.cpython-310.pyc b/vggt/utils/__pycache__/rotation.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..07bddb1e0ca864f013cb6417b03311a21c864f58
Binary files /dev/null and b/vggt/utils/__pycache__/rotation.cpython-310.pyc differ
diff --git a/vggt/utils/__pycache__/rotation.cpython-311.pyc b/vggt/utils/__pycache__/rotation.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..210d627dded44b5ca60706d22ad5858f838e9e58
Binary files /dev/null and b/vggt/utils/__pycache__/rotation.cpython-311.pyc differ
diff --git a/vggt/utils/__pycache__/rotation.cpython-312.pyc b/vggt/utils/__pycache__/rotation.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..58886e8dee887d5945fa9db7721dce6b34f11b46
Binary files /dev/null and b/vggt/utils/__pycache__/rotation.cpython-312.pyc differ
diff --git a/vggt/utils/__pycache__/visual_track.cpython-310.pyc b/vggt/utils/__pycache__/visual_track.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4bc57330046eb42b71375ebc4cbae8ae4b7997de
Binary files /dev/null and b/vggt/utils/__pycache__/visual_track.cpython-310.pyc differ
diff --git a/vggt/utils/__pycache__/visual_track.cpython-311.pyc b/vggt/utils/__pycache__/visual_track.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aed936b2af371fc97055b7f36a89e2fcbad69b58
Binary files /dev/null and b/vggt/utils/__pycache__/visual_track.cpython-311.pyc differ
diff --git a/vggt/utils/__pycache__/visual_track.cpython-312.pyc b/vggt/utils/__pycache__/visual_track.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a4e0739095f4c147b0e48b3e3d214f1078c05718
Binary files /dev/null and b/vggt/utils/__pycache__/visual_track.cpython-312.pyc differ
diff --git a/vggt/utils/geometry.py b/vggt/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ebd25dbc6cac6b0095956524c4f0628410dd5cb
--- /dev/null
+++ b/vggt/utils/geometry.py
@@ -0,0 +1,166 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import torch
+import numpy as np
+
+
+def unproject_depth_map_to_point_map(
+ depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray
+) -> np.ndarray:
+ """
+ Unproject a batch of depth maps to 3D world coordinates.
+
+ Args:
+ depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W)
+ extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4)
+ intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3)
+
+ Returns:
+ np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3)
+ """
+ if isinstance(depth_map, torch.Tensor):
+ depth_map = depth_map.cpu().numpy()
+ if isinstance(extrinsics_cam, torch.Tensor):
+ extrinsics_cam = extrinsics_cam.cpu().numpy()
+ if isinstance(intrinsics_cam, torch.Tensor):
+ intrinsics_cam = intrinsics_cam.cpu().numpy()
+
+ world_points_list = []
+ for frame_idx in range(depth_map.shape[0]):
+ cur_world_points, _, _ = depth_to_world_coords_points(
+ depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx]
+ )
+ world_points_list.append(cur_world_points)
+ world_points_array = np.stack(world_points_list, axis=0)
+
+ return world_points_array
+
+
+def depth_to_world_coords_points(
+ depth_map: np.ndarray,
+ extrinsic: np.ndarray,
+ intrinsic: np.ndarray,
+ eps=1e-8,
+) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """
+ Convert a depth map to world coordinates.
+
+ Args:
+ depth_map (np.ndarray): Depth map of shape (H, W).
+ intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
+ extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world.
+
+ Returns:
+ tuple[np.ndarray, np.ndarray]: World coordinates (H, W, 3) and valid depth mask (H, W).
+ """
+ if depth_map is None:
+ return None, None, None
+
+ # Valid depth mask
+ point_mask = depth_map > eps
+
+ # Convert depth map to camera coordinates
+ cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic)
+
+ # Multiply with the inverse of extrinsic matrix to transform to world coordinates
+ # extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4))
+ cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0]
+
+ R_cam_to_world = cam_to_world_extrinsic[:3, :3]
+ t_cam_to_world = cam_to_world_extrinsic[:3, 3]
+
+ # Apply the rotation and translation to the camera coordinates
+ world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world # HxWx3, 3x3 -> HxWx3
+ # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world
+
+ return world_coords_points, cam_coords_points, point_mask
+
+
+def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+ """
+ Convert a depth map to camera coordinates.
+
+ Args:
+ depth_map (np.ndarray): Depth map of shape (H, W).
+ intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
+
+ Returns:
+ tuple[np.ndarray, np.ndarray]: Camera coordinates (H, W, 3)
+ """
+ H, W = depth_map.shape
+ assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3"
+ assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew"
+
+ # Intrinsic parameters
+ fu, fv = intrinsic[0, 0], intrinsic[1, 1]
+ cu, cv = intrinsic[0, 2], intrinsic[1, 2]
+
+ # Generate grid of pixel coordinates
+ u, v = np.meshgrid(np.arange(W), np.arange(H))
+
+ # Unproject to camera coordinates
+ x_cam = (u - cu) * depth_map / fu
+ y_cam = (v - cv) * depth_map / fv
+ z_cam = depth_map
+
+ # Stack to form camera coordinates
+ cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
+
+ return cam_coords
+
+
+def closed_form_inverse_se3(se3, R=None, T=None):
+ """
+ Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.
+
+ If `R` and `T` are provided, they must correspond to the rotation and translation
+ components of `se3`. Otherwise, they will be extracted from `se3`.
+
+ Args:
+ se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
+ R (optional): Nx3x3 array or tensor of rotation matrices.
+ T (optional): Nx3x1 array or tensor of translation vectors.
+
+ Returns:
+ Inverted SE3 matrices with the same type and device as `se3`.
+
+ Shapes:
+ se3: (N, 4, 4)
+ R: (N, 3, 3)
+ T: (N, 3, 1)
+ """
+ # Check if se3 is a numpy array or a torch tensor
+ is_numpy = isinstance(se3, np.ndarray)
+
+ # Validate shapes
+ if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
+ raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.")
+
+ # Extract R and T if not provided
+ if R is None:
+ R = se3[:, :3, :3] # (N,3,3)
+ if T is None:
+ T = se3[:, :3, 3:] # (N,3,1)
+
+ # Transpose R
+ if is_numpy:
+ # Compute the transpose of the rotation for NumPy
+ R_transposed = np.transpose(R, (0, 2, 1))
+ # -R^T t for NumPy
+ top_right = -np.matmul(R_transposed, T)
+ inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
+ else:
+ R_transposed = R.transpose(1, 2) # (N,3,3)
+ top_right = -torch.bmm(R_transposed, T) # (N,3,1)
+ inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
+ inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)
+
+ inverted_matrix[:, :3, :3] = R_transposed
+ inverted_matrix[:, :3, 3:] = top_right
+
+ return inverted_matrix
diff --git a/vggt/utils/load_fn.py b/vggt/utils/load_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..14408aea5a30e9fa392d4569bae66869dc8a52ab
--- /dev/null
+++ b/vggt/utils/load_fn.py
@@ -0,0 +1,142 @@
+import torch
+import torch.nn.functional as F
+from typing import Optional, Union, Tuple, Dict, List
+
+from PIL import Image
+from torchvision import transforms as TF
+
+def load_and_preprocess_images(image_path_list, mode="crop"):
+ """
+ A quick start function to load and preprocess images for model input.
+ This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes.
+
+ Args:
+ image_path_list (list): List of paths to image files
+ mode (str, optional): Preprocessing mode, either "crop" or "pad".
+ - "crop" (default): Sets width to 518px and center crops height if needed.
+ - "pad": Preserves all pixels by making the largest dimension 518px
+ and padding the smaller dimension to reach a square shape.
+
+ Returns:
+ torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W)
+
+ Raises:
+ ValueError: If the input list is empty or if mode is invalid
+
+ Notes:
+ - Images with different dimensions will be padded with white (value=1.0)
+ - A warning is printed when images have different shapes
+ - When mode="crop": The function ensures width=518px while maintaining aspect ratio
+ and height is center-cropped if larger than 518px
+ - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio
+ and the smaller dimension is padded to reach a square shape (518x518)
+ - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements
+ """
+ # Check for empty list
+ if len(image_path_list) == 0:
+ raise ValueError("At least 1 image is required")
+
+ # Validate mode
+ if mode not in ["crop", "pad"]:
+ raise ValueError("Mode must be either 'crop' or 'pad'")
+
+ images = []
+ shapes = set()
+ to_tensor = TF.ToTensor()
+ target_size = 518
+
+ # First process all images and collect their shapes
+ for image_path in image_path_list:
+
+ # Open image
+ img = Image.open(image_path)
+
+ # If there's an alpha channel, blend onto white background:
+ if img.mode == "RGBA":
+ # Create white background
+ background = Image.new("RGBA", img.size, (255, 255, 255, 255))
+ # Alpha composite onto the white background
+ img = Image.alpha_composite(background, img)
+
+ # Now convert to "RGB" (this step assigns white for transparent areas)
+ img = img.convert("RGB")
+
+ width, height = img.size
+
+ if mode == "pad":
+ # Make the largest dimension 518px while maintaining aspect ratio
+ if width >= height:
+ new_width = target_size
+ new_height = round(height * (new_width / width) / 14) * 14 # Make divisible by 14
+ else:
+ new_height = target_size
+ new_width = round(width * (new_height / height) / 14) * 14 # Make divisible by 14
+ else: # mode == "crop"
+ # Original behavior: set width to 518px
+ new_width = target_size
+ # Calculate height maintaining aspect ratio, divisible by 14
+ new_height = round(height * (new_width / width) / 14) * 14
+
+ # Resize with new dimensions (width, height)
+ img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
+ img = to_tensor(img) # Convert to tensor (0, 1)
+
+ # Center crop height if it's larger than 518 (only in crop mode)
+ if mode == "crop" and new_height > target_size:
+ start_y = (new_height - target_size) // 2
+ img = img[:, start_y : start_y + target_size, :]
+
+ # For pad mode, pad to make a square of target_size x target_size
+ if mode == "pad":
+ h_padding = target_size - img.shape[1]
+ w_padding = target_size - img.shape[2]
+
+ if h_padding > 0 or w_padding > 0:
+ pad_top = h_padding // 2
+ pad_bottom = h_padding - pad_top
+ pad_left = w_padding // 2
+ pad_right = w_padding - pad_left
+
+ # Pad with white (value=1.0)
+ img = torch.nn.functional.pad(
+ img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
+ )
+
+ shapes.add((img.shape[1], img.shape[2]))
+ images.append(img)
+
+ # Check if we have different shapes
+ # In theory our model can also work well with different shapes
+ if len(shapes) > 1:
+ print(f"Warning: Found images with different shapes: {shapes}")
+ # Find maximum dimensions
+ max_height = max(shape[0] for shape in shapes)
+ max_width = max(shape[1] for shape in shapes)
+
+ # Pad images if necessary
+ padded_images = []
+ for img in images:
+ h_padding = max_height - img.shape[1]
+ w_padding = max_width - img.shape[2]
+
+ if h_padding > 0 or w_padding > 0:
+ pad_top = h_padding // 2
+ pad_bottom = h_padding - pad_top
+ pad_left = w_padding // 2
+ pad_right = w_padding - pad_left
+
+ img = torch.nn.functional.pad(
+ img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
+ )
+ padded_images.append(img)
+ images = padded_images
+
+ images = torch.stack(images) # concatenate images
+
+ # Ensure correct shape when single image
+ if len(image_path_list) == 1:
+ # Verify shape is (1, C, H, W)
+ if images.dim() == 3:
+ images = images.unsqueeze(0)
+
+ return images
diff --git a/vggt/utils/pose_enc.py b/vggt/utils/pose_enc.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f98b0878cb13451b8cdb80074349cbf2644c5fa
--- /dev/null
+++ b/vggt/utils/pose_enc.py
@@ -0,0 +1,130 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from .rotation import quat_to_mat, mat_to_quat
+
+
+def extri_intri_to_pose_encoding(
+ extrinsics,
+ intrinsics,
+ image_size_hw=None, # e.g., (256, 512)
+ pose_encoding_type="absT_quaR_FoV",
+):
+ """Convert camera extrinsics and intrinsics to a compact pose encoding.
+
+ This function transforms camera parameters into a unified pose encoding format,
+ which can be used for various downstream tasks like pose prediction or representation.
+
+ Args:
+ extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4,
+ where B is batch size and S is sequence length.
+ In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation.
+ The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector.
+ intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3.
+ Defined in pixels, with format:
+ [[fx, 0, cx],
+ [0, fy, cy],
+ [0, 0, 1]]
+ where fx, fy are focal lengths and (cx, cy) is the principal point
+ image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
+ Required for computing field of view values. For example: (256, 512).
+ pose_encoding_type (str): Type of pose encoding to use. Currently only
+ supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
+
+ Returns:
+ torch.Tensor: Encoded camera pose parameters with shape BxSx9.
+ For "absT_quaR_FoV" type, the 9 dimensions are:
+ - [:3] = absolute translation vector T (3D)
+ - [3:7] = rotation as quaternion quat (4D)
+ - [7:] = field of view (2D)
+ """
+
+ # extrinsics: BxSx3x4
+ # intrinsics: BxSx3x3
+
+ if pose_encoding_type == "absT_quaR_FoV":
+ R = extrinsics[:, :, :3, :3] # BxSx3x3
+ T = extrinsics[:, :, :3, 3] # BxSx3
+
+ quat = mat_to_quat(R)
+ # Note the order of h and w here
+ H, W = image_size_hw
+ fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
+ fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
+ pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
+ else:
+ raise NotImplementedError
+
+ return pose_encoding
+
+
+def pose_encoding_to_extri_intri(
+ pose_encoding,
+ image_size_hw=None, # e.g., (256, 512)
+ pose_encoding_type="absT_quaR_FoV",
+ build_intrinsics=True,
+):
+ """Convert a pose encoding back to camera extrinsics and intrinsics.
+
+ This function performs the inverse operation of extri_intri_to_pose_encoding,
+ reconstructing the full camera parameters from the compact encoding.
+
+ Args:
+ pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9,
+ where B is batch size and S is sequence length.
+ For "absT_quaR_FoV" type, the 9 dimensions are:
+ - [:3] = absolute translation vector T (3D)
+ - [3:7] = rotation as quaternion quat (4D)
+ - [7:] = field of view (2D)
+ image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
+ Required for reconstructing intrinsics from field of view values.
+ For example: (256, 512).
+ pose_encoding_type (str): Type of pose encoding used. Currently only
+ supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
+ build_intrinsics (bool): Whether to reconstruct the intrinsics matrix.
+ If False, only extrinsics are returned and intrinsics will be None.
+
+ Returns:
+ tuple: (extrinsics, intrinsics)
+ - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4.
+ In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world
+ transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is
+ a 3x1 translation vector.
+ - intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3,
+ or None if build_intrinsics is False. Defined in pixels, with format:
+ [[fx, 0, cx],
+ [0, fy, cy],
+ [0, 0, 1]]
+ where fx, fy are focal lengths and (cx, cy) is the principal point,
+ assumed to be at the center of the image (W/2, H/2).
+ """
+
+ intrinsics = None
+
+ if pose_encoding_type == "absT_quaR_FoV":
+ T = pose_encoding[..., :3]
+ quat = pose_encoding[..., 3:7]
+ fov_h = pose_encoding[..., 7]
+ fov_w = pose_encoding[..., 8]
+
+ R = quat_to_mat(quat)
+ extrinsics = torch.cat([R, T[..., None]], dim=-1)
+
+ if build_intrinsics:
+ H, W = image_size_hw
+ fy = (H / 2.0) / torch.tan(fov_h / 2.0)
+ fx = (W / 2.0) / torch.tan(fov_w / 2.0)
+ intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device)
+ intrinsics[..., 0, 0] = fx
+ intrinsics[..., 1, 1] = fy
+ intrinsics[..., 0, 2] = W / 2
+ intrinsics[..., 1, 2] = H / 2
+ intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1
+ else:
+ raise NotImplementedError
+
+ return extrinsics, intrinsics
diff --git a/vggt/utils/rotation.py b/vggt/utils/rotation.py
new file mode 100644
index 0000000000000000000000000000000000000000..657583e6915437c824c192d51939990b589a14fa
--- /dev/null
+++ b/vggt/utils/rotation.py
@@ -0,0 +1,138 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d
+
+import torch
+import numpy as np
+import torch.nn.functional as F
+
+
+def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Quaternion Order: XYZW or say ijkr, scalar-last
+
+ Convert rotations given as quaternions to rotation matrices.
+ Args:
+ quaternions: quaternions with real part last,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+ i, j, k, r = torch.unbind(quaternions, -1)
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+ o = torch.stack(
+ (
+ 1 - two_s * (j * j + k * k),
+ two_s * (i * j - k * r),
+ two_s * (i * k + j * r),
+ two_s * (i * j + k * r),
+ 1 - two_s * (i * i + k * k),
+ two_s * (j * k - i * r),
+ two_s * (i * k - j * r),
+ two_s * (j * k + i * r),
+ 1 - two_s * (i * i + j * j),
+ ),
+ -1,
+ )
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+
+def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as rotation matrices to quaternions.
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+ Returns:
+ quaternions with real part last, as tensor of shape (..., 4).
+ Quaternion Order: XYZW or say ijkr, scalar-last
+ """
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+
+ batch_dim = matrix.shape[:-2]
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)
+
+ q_abs = _sqrt_positive_part(
+ torch.stack(
+ [
+ 1.0 + m00 + m11 + m22,
+ 1.0 + m00 - m11 - m22,
+ 1.0 - m00 + m11 - m22,
+ 1.0 - m00 - m11 + m22,
+ ],
+ dim=-1,
+ )
+ )
+
+ # we produce the desired quaternion multiplied by each of r, i, j, k
+ quat_by_rijk = torch.stack(
+ [
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
+ ],
+ dim=-2,
+ )
+
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
+ # the candidate won't be picked.
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
+
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
+ # forall i; we pick the best-conditioned one (with the largest denominator)
+ out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,))
+
+ # Convert from rijk to ijkr
+ out = out[..., [1, 2, 3, 0]]
+
+ out = standardize_quaternion(out)
+
+ return out
+
+
+def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
+ """
+ Returns torch.sqrt(torch.max(0, x))
+ but with a zero subgradient where x is 0.
+ """
+ ret = torch.zeros_like(x)
+ positive_mask = x > 0
+ if torch.is_grad_enabled():
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
+ else:
+ ret = torch.where(positive_mask, torch.sqrt(x), ret)
+ return ret
+
+
+def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Convert a unit quaternion to a standard form: one in which the real
+ part is non negative.
+
+ Args:
+ quaternions: Quaternions with real part last,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Standardized quaternions as tensor of shape (..., 4).
+ """
+ return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
diff --git a/vggt/utils/visual_track.py b/vggt/utils/visual_track.py
new file mode 100644
index 0000000000000000000000000000000000000000..796c114ccba00b5f7850e04b9444a6cd5c44b154
--- /dev/null
+++ b/vggt/utils/visual_track.py
@@ -0,0 +1,239 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import cv2
+import torch
+import numpy as np
+import os
+
+
+def color_from_xy(x, y, W, H, cmap_name="hsv"):
+ """
+ Map (x, y) -> color in (R, G, B).
+ 1) Normalize x,y to [0,1].
+ 2) Combine them into a single scalar c in [0,1].
+ 3) Use matplotlib's colormap to convert c -> (R,G,B).
+
+ You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y).
+ """
+ import matplotlib.cm
+ import matplotlib.colors
+
+ x_norm = x / max(W - 1, 1)
+ y_norm = y / max(H - 1, 1)
+ # Simple combination:
+ c = (x_norm + y_norm) / 2.0
+
+ cmap = matplotlib.cm.get_cmap(cmap_name)
+ # cmap(c) -> (r,g,b,a) in [0,1]
+ rgba = cmap(c)
+ r, g, b = rgba[0], rgba[1], rgba[2]
+ return (r, g, b) # in [0,1], RGB order
+
+
+def get_track_colors_by_position(tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv"):
+ """
+ Given all tracks in one sample (b), compute a (N,3) array of RGB color values
+ in [0,255]. The color is determined by the (x,y) position in the first
+ visible frame for each track.
+
+ Args:
+ tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame.
+ vis_mask_b: (S, N) boolean mask; if None, assume all are visible.
+ image_width, image_height: used for normalizing (x, y).
+ cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet').
+
+ Returns:
+ track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255].
+ """
+ S, N, _ = tracks_b.shape
+ track_colors = np.zeros((N, 3), dtype=np.uint8)
+
+ if vis_mask_b is None:
+ # treat all as visible
+ vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device)
+
+ for i in range(N):
+ # Find first visible frame for track i
+ visible_frames = torch.where(vis_mask_b[:, i])[0]
+ if len(visible_frames) == 0:
+ # track is never visible; just assign black or something
+ track_colors[i] = (0, 0, 0)
+ continue
+
+ first_s = int(visible_frames[0].item())
+ # use that frame's (x,y)
+ x, y = tracks_b[first_s, i].tolist()
+
+ # map (x,y) -> (R,G,B) in [0,1]
+ r, g, b = color_from_xy(x, y, W=image_width, H=image_height, cmap_name=cmap_name)
+ # scale to [0,255]
+ r, g, b = int(r * 255), int(g * 255), int(b * 255)
+ track_colors[i] = (r, g, b)
+
+ return track_colors
+
+
+def visualize_tracks_on_images(
+ images,
+ tracks,
+ track_vis_mask=None,
+ out_dir="track_visuals_concat_by_xy",
+ image_format="CHW", # "CHW" or "HWC"
+ normalize_mode="[0,1]",
+ cmap_name="hsv", # e.g. "hsv", "rainbow", "jet"
+ frames_per_row=4, # New parameter for grid layout
+ save_grid=True, # Flag to control whether to save the grid image
+):
+ """
+ Visualizes frames in a grid layout with specified frames per row.
+ Each track's color is determined by its (x,y) position
+ in the first visible frame (or frame 0 if always visible).
+ Finally convert the BGR result to RGB before saving.
+ Also saves each individual frame as a separate PNG file.
+
+ Args:
+ images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC.
+ tracks: torch.Tensor (S, N, 2), last dim = (x, y).
+ track_vis_mask: torch.Tensor (S, N) or None.
+ out_dir: folder to save visualizations.
+ image_format: "CHW" or "HWC".
+ normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255
+ cmap_name: a matplotlib colormap name for color_from_xy.
+ frames_per_row: number of frames to display in each row of the grid.
+ save_grid: whether to save all frames in one grid image.
+
+ Returns:
+ None (saves images in out_dir).
+ """
+
+ if len(tracks.shape) == 4:
+ tracks = tracks.squeeze(0)
+ images = images.squeeze(0)
+ if track_vis_mask is not None:
+ track_vis_mask = track_vis_mask.squeeze(0)
+
+ import matplotlib
+
+ matplotlib.use("Agg") # for non-interactive (optional)
+
+ os.makedirs(out_dir, exist_ok=True)
+
+ S = images.shape[0]
+ _, N, _ = tracks.shape # (S, N, 2)
+
+ # Move to CPU
+ images = images.cpu().clone()
+ tracks = tracks.cpu().clone()
+ if track_vis_mask is not None:
+ track_vis_mask = track_vis_mask.cpu().clone()
+
+ # Infer H, W from images shape
+ if image_format == "CHW":
+ # e.g. images[s].shape = (3, H, W)
+ H, W = images.shape[2], images.shape[3]
+ else:
+ # e.g. images[s].shape = (H, W, 3)
+ H, W = images.shape[1], images.shape[2]
+
+ # Pre-compute the color for each track i based on first visible position
+ track_colors_rgb = get_track_colors_by_position(
+ tracks, # shape (S, N, 2)
+ vis_mask_b=track_vis_mask if track_vis_mask is not None else None,
+ image_width=W,
+ image_height=H,
+ cmap_name=cmap_name,
+ )
+
+ # We'll accumulate each frame's drawn image in a list
+ frame_images = []
+
+ for s in range(S):
+ # shape => either (3, H, W) or (H, W, 3)
+ img = images[s]
+
+ # Convert to (H, W, 3)
+ if image_format == "CHW":
+ img = img.permute(1, 2, 0) # (H, W, 3)
+ # else "HWC", do nothing
+
+ img = img.numpy().astype(np.float32)
+
+ # Scale to [0,255] if needed
+ if normalize_mode == "[0,1]":
+ img = np.clip(img, 0, 1) * 255.0
+ elif normalize_mode == "[-1,1]":
+ img = (img + 1.0) * 0.5 * 255.0
+ img = np.clip(img, 0, 255.0)
+ # else no normalization
+
+ # Convert to uint8
+ img = img.astype(np.uint8)
+
+ # For drawing in OpenCV, convert to BGR
+ img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+
+ # Draw each visible track
+ cur_tracks = tracks[s] # shape (N, 2)
+ if track_vis_mask is not None:
+ valid_indices = torch.where(track_vis_mask[s])[0]
+ else:
+ valid_indices = range(N)
+
+ cur_tracks_np = cur_tracks.numpy()
+ for i in valid_indices:
+ x, y = cur_tracks_np[i]
+ pt = (int(round(x)), int(round(y)))
+
+ # track_colors_rgb[i] is (R,G,B). For OpenCV circle, we need BGR
+ R, G, B = track_colors_rgb[i]
+ color_bgr = (int(B), int(G), int(R))
+ cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1)
+
+ # Convert back to RGB for consistent final saving:
+ img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
+
+ # Save individual frame
+ frame_path = os.path.join(out_dir, f"frame_{s:04d}.png")
+ # Convert to BGR for OpenCV imwrite
+ frame_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
+ cv2.imwrite(frame_path, frame_bgr)
+
+ frame_images.append(img_rgb)
+
+ # Only create and save the grid image if save_grid is True
+ if save_grid:
+ # Calculate grid dimensions
+ num_rows = (S + frames_per_row - 1) // frames_per_row # Ceiling division
+
+ # Create a grid of images
+ grid_img = None
+ for row in range(num_rows):
+ start_idx = row * frames_per_row
+ end_idx = min(start_idx + frames_per_row, S)
+
+ # Concatenate this row horizontally
+ row_img = np.concatenate(frame_images[start_idx:end_idx], axis=1)
+
+ # If this row has fewer than frames_per_row images, pad with black
+ if end_idx - start_idx < frames_per_row:
+ padding_width = (frames_per_row - (end_idx - start_idx)) * W
+ padding = np.zeros((H, padding_width, 3), dtype=np.uint8)
+ row_img = np.concatenate([row_img, padding], axis=1)
+
+ # Add this row to the grid
+ if grid_img is None:
+ grid_img = row_img
+ else:
+ grid_img = np.concatenate([grid_img, row_img], axis=0)
+
+ out_path = os.path.join(out_dir, "tracks_grid.png")
+ # Convert back to BGR for OpenCV imwrite
+ grid_img_bgr = cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR)
+ cv2.imwrite(out_path, grid_img_bgr)
+ print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}")
+
+ print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png")
diff --git a/vggt_to_colmap.py b/vggt_to_colmap.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c7a193dc247768e0a8645109eaa849bccce7378
--- /dev/null
+++ b/vggt_to_colmap.py
@@ -0,0 +1,593 @@
+import os
+import argparse
+import numpy as np
+import torch
+import glob
+import struct
+from scipy.spatial.transform import Rotation
+import sys
+from PIL import Image
+import cv2
+import requests
+import tempfile
+
+from vggt.models.vggt import VGGT
+from vggt.utils.load_fn import load_and_preprocess_images
+from vggt.utils.pose_enc import pose_encoding_to_extri_intri
+from vggt.utils.geometry import unproject_depth_map_to_point_map
+
+def load_model(device=None):
+ """Load and initialize the VGGT model."""
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Using device: {device}")
+
+ model = VGGT.from_pretrained("facebook/VGGT-1B")
+
+ # model = VGGT()
+ # _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
+ # model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
+
+ model.eval()
+ model = model.to(device)
+ return model, device
+
+def process_images(image_dir, model, device):
+ """Process images with VGGT and return predictions."""
+ image_names = glob.glob(os.path.join(image_dir, "*"))
+ image_names = sorted([f for f in image_names if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
+ print(f"Found {len(image_names)} images")
+
+ if len(image_names) == 0:
+ raise ValueError(f"No images found in {image_dir}")
+
+ original_images = []
+ for img_path in image_names:
+ img = Image.open(img_path).convert('RGB')
+ original_images.append(np.array(img))
+
+ images = load_and_preprocess_images(image_names).to(device)
+ print(f"Preprocessed images shape: {images.shape}")
+
+ print("Running inference...")
+ dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
+
+ with torch.no_grad():
+ with torch.cuda.amp.autocast(dtype=dtype):
+ predictions = model(images)
+
+ print("Converting pose encoding to camera parameters...")
+ extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])
+ predictions["extrinsic"] = extrinsic
+ predictions["intrinsic"] = intrinsic
+
+ for key in predictions.keys():
+ if isinstance(predictions[key], torch.Tensor):
+ predictions[key] = predictions[key].cpu().numpy().squeeze(0) # remove batch dimension
+
+ print("Computing 3D points from depth maps...")
+ depth_map = predictions["depth"] # (S, H, W, 1)
+ world_points = unproject_depth_map_to_point_map(depth_map, predictions["extrinsic"], predictions["intrinsic"])
+ predictions["world_points_from_depth"] = world_points
+
+ predictions["original_images"] = original_images
+
+ S, H, W = world_points.shape[:3]
+ normalized_images = np.zeros((S, H, W, 3), dtype=np.float32)
+
+ for i, img in enumerate(original_images):
+ resized_img = cv2.resize(img, (W, H))
+ normalized_images[i] = resized_img / 255.0
+
+ predictions["images"] = normalized_images
+
+ return predictions, image_names
+
+def extrinsic_to_colmap_format(extrinsics):
+ """Convert extrinsic matrices to COLMAP format (quaternion + translation)."""
+ num_cameras = extrinsics.shape[0]
+ quaternions = []
+ translations = []
+
+ for i in range(num_cameras):
+ # VGGT's extrinsic is camera-to-world (R|t) format
+ R = extrinsics[i, :3, :3]
+ t = extrinsics[i, :3, 3]
+
+ # Convert rotation matrix to quaternion
+ # COLMAP quaternion format is [qw, qx, qy, qz]
+ rot = Rotation.from_matrix(R)
+ quat = rot.as_quat() # scipy returns [x, y, z, w]
+ quat = np.array([quat[3], quat[0], quat[1], quat[2]]) # Convert to [w, x, y, z]
+
+ quaternions.append(quat)
+ translations.append(t)
+
+ return np.array(quaternions), np.array(translations)
+
+def download_file_from_url(url, filename):
+ """Downloads a file from a URL, handling redirects."""
+ try:
+ response = requests.get(url, allow_redirects=False)
+ response.raise_for_status()
+
+ if response.status_code == 302:
+ redirect_url = response.headers["Location"]
+ response = requests.get(redirect_url, stream=True)
+ response.raise_for_status()
+ else:
+ response = requests.get(url, stream=True)
+ response.raise_for_status()
+
+ with open(filename, "wb") as f:
+ for chunk in response.iter_content(chunk_size=8192):
+ f.write(chunk)
+ print(f"Downloaded {filename} successfully.")
+ return True
+
+ except requests.exceptions.RequestException as e:
+ print(f"Error downloading file: {e}")
+ return False
+
+def segment_sky(image_path, onnx_session, mask_filename=None):
+ """
+ Segments sky from an image using an ONNX model.
+ """
+ image = cv2.imread(image_path)
+
+ result_map = run_skyseg(onnx_session, [320, 320], image)
+ result_map_original = cv2.resize(result_map, (image.shape[1], image.shape[0]))
+
+ # Fix: Invert the mask so that 255 = non-sky, 0 = sky
+ # The model outputs low values for sky, high values for non-sky
+ output_mask = np.zeros_like(result_map_original)
+ output_mask[result_map_original < 32] = 255 # Use threshold of 32
+
+ if mask_filename is not None:
+ os.makedirs(os.path.dirname(mask_filename), exist_ok=True)
+ cv2.imwrite(mask_filename, output_mask)
+
+ return output_mask
+
+def run_skyseg(onnx_session, input_size, image):
+ """
+ Runs sky segmentation inference using ONNX model.
+ """
+ import copy
+
+ temp_image = copy.deepcopy(image)
+ resize_image = cv2.resize(temp_image, dsize=(input_size[0], input_size[1]))
+ x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB)
+ x = np.array(x, dtype=np.float32)
+ mean = [0.485, 0.456, 0.406]
+ std = [0.229, 0.224, 0.225]
+ x = (x / 255 - mean) / std
+ x = x.transpose(2, 0, 1)
+ x = x.reshape(-1, 3, input_size[0], input_size[1]).astype("float32")
+
+ input_name = onnx_session.get_inputs()[0].name
+ output_name = onnx_session.get_outputs()[0].name
+ onnx_result = onnx_session.run([output_name], {input_name: x})
+
+ onnx_result = np.array(onnx_result).squeeze()
+ min_value = np.min(onnx_result)
+ max_value = np.max(onnx_result)
+ onnx_result = (onnx_result - min_value) / (max_value - min_value)
+ onnx_result *= 255
+ onnx_result = onnx_result.astype("uint8")
+
+ return onnx_result
+
+def filter_and_prepare_points(predictions, conf_threshold, mask_sky=False, mask_black_bg=False,
+ mask_white_bg=False, stride=1, prediction_mode="Depthmap and Camera Branch"):
+ """
+ Filter points based on confidence and prepare for COLMAP format.
+ Implementation matches the conventions in the original VGGT code.
+ """
+
+ if "Pointmap" in prediction_mode:
+ print("Using Pointmap Branch")
+ if "world_points" in predictions:
+ pred_world_points = predictions["world_points"]
+ pred_world_points_conf = predictions.get("world_points_conf", np.ones_like(pred_world_points[..., 0]))
+ else:
+ print("Warning: world_points not found in predictions, falling back to depth-based points")
+ pred_world_points = predictions["world_points_from_depth"]
+ pred_world_points_conf = predictions.get("depth_conf", np.ones_like(pred_world_points[..., 0]))
+ else:
+ print("Using Depthmap and Camera Branch")
+ pred_world_points = predictions["world_points_from_depth"]
+ pred_world_points_conf = predictions.get("depth_conf", np.ones_like(pred_world_points[..., 0]))
+
+ colors_rgb = predictions["images"]
+
+ S, H, W = pred_world_points.shape[:3]
+ if colors_rgb.shape[:3] != (S, H, W):
+ print(f"Reshaping colors_rgb from {colors_rgb.shape} to match {(S, H, W, 3)}")
+ reshaped_colors = np.zeros((S, H, W, 3), dtype=np.float32)
+ for i in range(S):
+ if i < len(colors_rgb):
+ reshaped_colors[i] = cv2.resize(colors_rgb[i], (W, H))
+ colors_rgb = reshaped_colors
+
+ colors_rgb = (colors_rgb * 255).astype(np.uint8)
+
+ if mask_sky:
+ print("Applying sky segmentation mask")
+ try:
+ import onnxruntime
+
+ with tempfile.TemporaryDirectory() as temp_dir:
+ print(f"Created temporary directory for sky segmentation: {temp_dir}")
+ temp_images_dir = os.path.join(temp_dir, "images")
+ sky_masks_dir = os.path.join(temp_dir, "sky_masks")
+ os.makedirs(temp_images_dir, exist_ok=True)
+ os.makedirs(sky_masks_dir, exist_ok=True)
+
+ image_list = []
+ for i, img in enumerate(colors_rgb):
+ img_path = os.path.join(temp_images_dir, f"image_{i:04d}.png")
+ image_list.append(img_path)
+ cv2.imwrite(img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
+
+
+ skyseg_path = os.path.join(temp_dir, "skyseg.onnx")
+ if not os.path.exists("skyseg.onnx"):
+ print("Downloading skyseg.onnx...")
+ download_success = download_file_from_url(
+ "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx",
+ skyseg_path
+ )
+ if not download_success:
+ print("Failed to download skyseg model, skipping sky filtering")
+ mask_sky = False
+ else:
+
+ import shutil
+ shutil.copy("skyseg.onnx", skyseg_path)
+
+ if mask_sky:
+ skyseg_session = onnxruntime.InferenceSession(skyseg_path)
+ sky_mask_list = []
+
+ for img_path in image_list:
+ mask_path = os.path.join(sky_masks_dir, os.path.basename(img_path))
+ sky_mask = segment_sky(img_path, skyseg_session, mask_path)
+
+ if sky_mask.shape[0] != H or sky_mask.shape[1] != W:
+ sky_mask = cv2.resize(sky_mask, (W, H))
+
+ sky_mask_list.append(sky_mask)
+
+ sky_mask_array = np.array(sky_mask_list)
+
+ sky_mask_binary = (sky_mask_array > 0.1).astype(np.float32)
+ pred_world_points_conf = pred_world_points_conf * sky_mask_binary
+ print(f"Applied sky mask, shape: {sky_mask_binary.shape}")
+
+ except (ImportError, Exception) as e:
+ print(f"Error in sky segmentation: {e}")
+ mask_sky = False
+
+ vertices_3d = pred_world_points.reshape(-1, 3)
+ conf = pred_world_points_conf.reshape(-1)
+ colors_rgb_flat = colors_rgb.reshape(-1, 3)
+
+
+
+ if len(conf) != len(colors_rgb_flat):
+ print(f"WARNING: Shape mismatch between confidence ({len(conf)}) and colors ({len(colors_rgb_flat)})")
+ min_size = min(len(conf), len(colors_rgb_flat))
+ conf = conf[:min_size]
+ vertices_3d = vertices_3d[:min_size]
+ colors_rgb_flat = colors_rgb_flat[:min_size]
+
+ if conf_threshold == 0.0:
+ conf_thres_value = 0.0
+ else:
+ conf_thres_value = np.percentile(conf, conf_threshold)
+
+ print(f"Using confidence threshold: {conf_threshold}% (value: {conf_thres_value:.4f})")
+ conf_mask = (conf >= conf_thres_value) & (conf > 1e-5)
+
+ if mask_black_bg:
+ print("Filtering black background")
+ black_bg_mask = colors_rgb_flat.sum(axis=1) >= 16
+ conf_mask = conf_mask & black_bg_mask
+
+ if mask_white_bg:
+ print("Filtering white background")
+ white_bg_mask = ~((colors_rgb_flat[:, 0] > 240) & (colors_rgb_flat[:, 1] > 240) & (colors_rgb_flat[:, 2] > 240))
+ conf_mask = conf_mask & white_bg_mask
+
+ filtered_vertices = vertices_3d[conf_mask]
+ filtered_colors = colors_rgb_flat[conf_mask]
+
+ if len(filtered_vertices) == 0:
+ print("Warning: No points remaining after filtering. Using default point.")
+ filtered_vertices = np.array([[0, 0, 0]])
+ filtered_colors = np.array([[200, 200, 200]])
+
+ print(f"Filtered to {len(filtered_vertices)} points")
+
+ points3D = []
+ point_indices = {}
+ image_points2D = [[] for _ in range(len(pred_world_points))]
+
+ print(f"Preparing points for COLMAP format with stride {stride}...")
+
+ total_points = 0
+ for img_idx in range(S):
+ for y in range(0, H, stride):
+ for x in range(0, W, stride):
+ flat_idx = img_idx * H * W + y * W + x
+
+ if flat_idx >= len(conf):
+ continue
+
+ if conf[flat_idx] < conf_thres_value or conf[flat_idx] <= 1e-5:
+ continue
+
+ if mask_black_bg and colors_rgb_flat[flat_idx].sum() < 16:
+ continue
+
+ if mask_white_bg and all(colors_rgb_flat[flat_idx] > 240):
+ continue
+
+ point3D = vertices_3d[flat_idx]
+ rgb = colors_rgb_flat[flat_idx]
+
+ if not np.all(np.isfinite(point3D)):
+ continue
+
+ point_hash = hash_point(point3D, scale=100)
+
+ if point_hash not in point_indices:
+ point_idx = len(points3D)
+ point_indices[point_hash] = point_idx
+
+ point_entry = {
+ "id": point_idx,
+ "xyz": point3D,
+ "rgb": rgb,
+ "error": 1.0,
+ "track": [(img_idx, len(image_points2D[img_idx]))]
+ }
+ points3D.append(point_entry)
+ total_points += 1
+ else:
+ point_idx = point_indices[point_hash]
+ points3D[point_idx]["track"].append((img_idx, len(image_points2D[img_idx])))
+
+ image_points2D[img_idx].append((x, y, point_indices[point_hash]))
+
+ print(f"Prepared {len(points3D)} 3D points with {sum(len(pts) for pts in image_points2D)} observations for COLMAP")
+ return points3D, image_points2D
+
+def hash_point(point, scale=100):
+ """Create a hash for a 3D point by quantizing coordinates."""
+ quantized = tuple(np.round(point * scale).astype(int))
+ return hash(quantized)
+
+def write_colmap_cameras_txt(file_path, intrinsics, image_width, image_height):
+ """Write camera intrinsics to COLMAP cameras.txt format."""
+ with open(file_path, 'w') as f:
+ f.write("# Camera list with one line of data per camera:\n")
+ f.write("# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n")
+ f.write(f"# Number of cameras: {len(intrinsics)}\n")
+
+ for i, intrinsic in enumerate(intrinsics):
+ camera_id = i + 1 # COLMAP uses 1-indexed camera IDs
+ model = "PINHOLE"
+
+ fx = intrinsic[0, 0]
+ fy = intrinsic[1, 1]
+ cx = intrinsic[0, 2]
+ cy = intrinsic[1, 2]
+
+ f.write(f"{camera_id} {model} {image_width} {image_height} {fx} {fy} {cx} {cy}\n")
+
+def write_colmap_images_txt(file_path, quaternions, translations, image_points2D, image_names):
+ """Write camera poses and keypoints to COLMAP images.txt format."""
+ with open(file_path, 'w') as f:
+ f.write("# Image list with two lines of data per image:\n")
+ f.write("# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n")
+ f.write("# POINTS2D[] as (X, Y, POINT3D_ID)\n")
+
+ num_points = sum(len(points) for points in image_points2D)
+ avg_points = num_points / len(image_points2D) if image_points2D else 0
+ f.write(f"# Number of images: {len(quaternions)}, mean observations per image: {avg_points:.1f}\n")
+
+ for i in range(len(quaternions)):
+ image_id = i + 1
+ camera_id = i + 1
+
+ qw, qx, qy, qz = quaternions[i]
+ tx, ty, tz = translations[i]
+
+ f.write(f"{image_id} {qw} {qx} {qy} {qz} {tx} {ty} {tz} {camera_id} {os.path.basename(image_names[i])}\n")
+
+ points_line = " ".join([f"{x} {y} {point3d_id+1}" for x, y, point3d_id in image_points2D[i]])
+ f.write(f"{points_line}\n")
+
+def write_colmap_points3D_txt(file_path, points3D):
+ """Write 3D points and tracks to COLMAP points3D.txt format."""
+ with open(file_path, 'w') as f:
+ f.write("# 3D point list with one line of data per point:\n")
+ f.write("# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n")
+
+ avg_track_length = sum(len(point["track"]) for point in points3D) / len(points3D) if points3D else 0
+ f.write(f"# Number of points: {len(points3D)}, mean track length: {avg_track_length:.4f}\n")
+
+ for point in points3D:
+ point_id = point["id"] + 1
+ x, y, z = point["xyz"]
+ r, g, b = point["rgb"]
+ error = point["error"]
+
+ track = " ".join([f"{img_id+1} {point2d_idx}" for img_id, point2d_idx in point["track"]])
+
+ f.write(f"{point_id} {x} {y} {z} {int(r)} {int(g)} {int(b)} {error} {track}\n")
+
+def write_colmap_cameras_bin(file_path, intrinsics, image_width, image_height):
+ """Write camera intrinsics to COLMAP cameras.bin format."""
+ with open(file_path, 'wb') as fid:
+ # Write number of cameras (uint64)
+ fid.write(struct.pack(' trimesh.Scene:
+ """
+ Converts VGGT predictions to a 3D scene represented as a GLB file.
+
+ Args:
+ predictions (dict): Dictionary containing model predictions with keys:
+ - world_points: 3D point coordinates (S, H, W, 3)
+ - world_points_conf: Confidence scores (S, H, W)
+ - images: Input images (S, H, W, 3)
+ - extrinsic: Camera extrinsic matrices (S, 3, 4)
+ conf_thres (float): Percentage of low-confidence points to filter out (default: 50.0)
+ filter_by_frames (str): Frame filter specification (default: "all")
+ mask_black_bg (bool): Mask out black background pixels (default: False)
+ mask_white_bg (bool): Mask out white background pixels (default: False)
+ show_cam (bool): Include camera visualization (default: True)
+ mask_sky (bool): Apply sky segmentation mask (default: False)
+ target_dir (str): Output directory for intermediate files (default: None)
+ prediction_mode (str): Prediction mode selector (default: "Predicted Pointmap")
+
+ Returns:
+ trimesh.Scene: Processed 3D scene containing point cloud and cameras
+
+ Raises:
+ ValueError: If input predictions structure is invalid
+ """
+ if not isinstance(predictions, dict):
+ raise ValueError("predictions must be a dictionary")
+
+ if conf_thres is None:
+ conf_thres = 10.0
+
+ print("Building GLB scene")
+ selected_frame_idx = None
+ if filter_by_frames != "all" and filter_by_frames != "All":
+ try:
+ # Extract the index part before the colon
+ selected_frame_idx = int(filter_by_frames.split(":")[0])
+ except (ValueError, IndexError):
+ pass
+
+ if "Pointmap" in prediction_mode:
+ print("Using Pointmap Branch")
+ if "world_points" in predictions:
+ pred_world_points = predictions["world_points"] # No batch dimension to remove
+ pred_world_points_conf = predictions.get("world_points_conf", np.ones_like(pred_world_points[..., 0]))
+ else:
+ print("Warning: world_points not found in predictions, falling back to depth-based points")
+ pred_world_points = predictions["world_points_from_depth"]
+ pred_world_points_conf = predictions.get("depth_conf", np.ones_like(pred_world_points[..., 0]))
+ else:
+ print("Using Depthmap and Camera Branch")
+ pred_world_points = predictions["world_points_from_depth"]
+ pred_world_points_conf = predictions.get("depth_conf", np.ones_like(pred_world_points[..., 0]))
+
+ # Get images from predictions
+ images = predictions["images"]
+ # Use extrinsic matrices instead of pred_extrinsic_list
+ camera_matrices = predictions["extrinsic"]
+
+ if mask_sky:
+ if target_dir is not None:
+ import onnxruntime
+
+ skyseg_session = None
+ target_dir_images = target_dir + "/images"
+ image_list = sorted(os.listdir(target_dir_images))
+ sky_mask_list = []
+
+ # Get the shape of pred_world_points_conf to match
+ S, H, W = (
+ pred_world_points_conf.shape
+ if hasattr(pred_world_points_conf, "shape")
+ else (len(images), images.shape[1], images.shape[2])
+ )
+
+ # Download skyseg.onnx if it doesn't exist
+ if not os.path.exists("skyseg.onnx"):
+ print("Downloading skyseg.onnx...")
+ download_file_from_url(
+ "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx", "skyseg.onnx"
+ )
+
+ for i, image_name in enumerate(image_list):
+ image_filepath = os.path.join(target_dir_images, image_name)
+ mask_filepath = os.path.join(target_dir, "sky_masks", image_name)
+
+ # Check if mask already exists
+ if os.path.exists(mask_filepath):
+ # Load existing mask
+ sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
+ else:
+ # Generate new mask
+ if skyseg_session is None:
+ skyseg_session = onnxruntime.InferenceSession("skyseg.onnx")
+ sky_mask = segment_sky(image_filepath, skyseg_session, mask_filepath)
+
+ # Resize mask to match H×W if needed
+ if sky_mask.shape[0] != H or sky_mask.shape[1] != W:
+ sky_mask = cv2.resize(sky_mask, (W, H))
+
+ sky_mask_list.append(sky_mask)
+
+ # Convert list to numpy array with shape S×H×W
+ sky_mask_array = np.array(sky_mask_list)
+
+ # Apply sky mask to confidence scores
+ sky_mask_binary = (sky_mask_array > 0.1).astype(np.float32)
+ pred_world_points_conf = pred_world_points_conf * sky_mask_binary
+
+ if selected_frame_idx is not None:
+ pred_world_points = pred_world_points[selected_frame_idx][None]
+ pred_world_points_conf = pred_world_points_conf[selected_frame_idx][None]
+ images = images[selected_frame_idx][None]
+ camera_matrices = camera_matrices[selected_frame_idx][None]
+
+ vertices_3d = pred_world_points.reshape(-1, 3)
+ # Handle different image formats - check if images need transposing
+ if images.ndim == 4 and images.shape[1] == 3: # NCHW format
+ colors_rgb = np.transpose(images, (0, 2, 3, 1))
+ else: # Assume already in NHWC format
+ colors_rgb = images
+ colors_rgb = (colors_rgb.reshape(-1, 3) * 255).astype(np.uint8)
+
+ conf = pred_world_points_conf.reshape(-1)
+ # Convert percentage threshold to actual confidence value
+ if conf_thres == 0.0:
+ conf_threshold = 0.0
+ else:
+ conf_threshold = np.percentile(conf, conf_thres)
+
+ conf_mask = (conf >= conf_threshold) & (conf > 1e-5)
+
+ if mask_black_bg:
+ black_bg_mask = colors_rgb.sum(axis=1) >= 16
+ conf_mask = conf_mask & black_bg_mask
+
+ if mask_white_bg:
+ # Filter out white background pixels (RGB values close to white)
+ # Consider pixels white if all RGB values are above 240
+ white_bg_mask = ~((colors_rgb[:, 0] > 240) & (colors_rgb[:, 1] > 240) & (colors_rgb[:, 2] > 240))
+ conf_mask = conf_mask & white_bg_mask
+
+ vertices_3d = vertices_3d[conf_mask]
+ colors_rgb = colors_rgb[conf_mask]
+
+ if vertices_3d is None or np.asarray(vertices_3d).size == 0:
+ vertices_3d = np.array([[1, 0, 0]])
+ colors_rgb = np.array([[255, 255, 255]])
+ scene_scale = 1
+ else:
+ # Calculate the 5th and 95th percentiles along each axis
+ lower_percentile = np.percentile(vertices_3d, 5, axis=0)
+ upper_percentile = np.percentile(vertices_3d, 95, axis=0)
+
+ # Calculate the diagonal length of the percentile bounding box
+ scene_scale = np.linalg.norm(upper_percentile - lower_percentile)
+
+ colormap = matplotlib.colormaps.get_cmap("gist_rainbow")
+
+ # Initialize a 3D scene
+ scene_3d = trimesh.Scene()
+
+ # Add point cloud data to the scene
+ point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb)
+
+ scene_3d.add_geometry(point_cloud_data)
+
+ # Prepare 4x4 matrices for camera extrinsics
+ num_cameras = len(camera_matrices)
+ extrinsics_matrices = np.zeros((num_cameras, 4, 4))
+ extrinsics_matrices[:, :3, :4] = camera_matrices
+ extrinsics_matrices[:, 3, 3] = 1
+
+ if show_cam:
+ # Add camera models to the scene
+ for i in range(num_cameras):
+ world_to_camera = extrinsics_matrices[i]
+ camera_to_world = np.linalg.inv(world_to_camera)
+ rgba_color = colormap(i / num_cameras)
+ current_color = tuple(int(255 * x) for x in rgba_color[:3])
+
+ integrate_camera_into_scene(scene_3d, camera_to_world, current_color, scene_scale)
+
+ # Align scene to the observation of the first camera
+ scene_3d = apply_scene_alignment(scene_3d, extrinsics_matrices)
+
+ print("GLB Scene built")
+ return scene_3d
+
+
+def integrate_camera_into_scene(
+ scene: trimesh.Scene,
+ transform: np.ndarray,
+ face_colors: tuple,
+ scene_scale: float,
+):
+ """
+ Integrates a fake camera mesh into the 3D scene.
+
+ Args:
+ scene (trimesh.Scene): The 3D scene to add the camera model.
+ transform (np.ndarray): Transformation matrix for camera positioning.
+ face_colors (tuple): Color of the camera face.
+ scene_scale (float): Scale of the scene.
+ """
+
+ cam_width = scene_scale * 0.05
+ cam_height = scene_scale * 0.1
+
+ # Create cone shape for camera
+ rot_45_degree = np.eye(4)
+ rot_45_degree[:3, :3] = Rotation.from_euler("z", 45, degrees=True).as_matrix()
+ rot_45_degree[2, 3] = -cam_height
+
+ opengl_transform = get_opengl_conversion_matrix()
+ # Combine transformations
+ complete_transform = transform @ opengl_transform @ rot_45_degree
+ camera_cone_shape = trimesh.creation.cone(cam_width, cam_height, sections=4)
+
+ # Generate mesh for the camera
+ slight_rotation = np.eye(4)
+ slight_rotation[:3, :3] = Rotation.from_euler("z", 2, degrees=True).as_matrix()
+
+ vertices_combined = np.concatenate(
+ [
+ camera_cone_shape.vertices,
+ 0.95 * camera_cone_shape.vertices,
+ transform_points(slight_rotation, camera_cone_shape.vertices),
+ ]
+ )
+ vertices_transformed = transform_points(complete_transform, vertices_combined)
+
+ mesh_faces = compute_camera_faces(camera_cone_shape)
+
+ # Add the camera mesh to the scene
+ camera_mesh = trimesh.Trimesh(vertices=vertices_transformed, faces=mesh_faces)
+ camera_mesh.visual.face_colors[:, :3] = face_colors
+ scene.add_geometry(camera_mesh)
+
+
+def apply_scene_alignment(scene_3d: trimesh.Scene, extrinsics_matrices: np.ndarray) -> trimesh.Scene:
+ """
+ Aligns the 3D scene based on the extrinsics of the first camera.
+
+ Args:
+ scene_3d (trimesh.Scene): The 3D scene to be aligned.
+ extrinsics_matrices (np.ndarray): Camera extrinsic matrices.
+
+ Returns:
+ trimesh.Scene: Aligned 3D scene.
+ """
+ # Set transformations for scene alignment
+ opengl_conversion_matrix = get_opengl_conversion_matrix()
+
+ # Rotation matrix for alignment (180 degrees around the y-axis)
+ align_rotation = np.eye(4)
+ align_rotation[:3, :3] = Rotation.from_euler("y", 180, degrees=True).as_matrix()
+
+ # Apply transformation
+ initial_transformation = np.linalg.inv(extrinsics_matrices[0]) @ opengl_conversion_matrix @ align_rotation
+ scene_3d.apply_transform(initial_transformation)
+ return scene_3d
+
+
+def get_opengl_conversion_matrix() -> np.ndarray:
+ """
+ Constructs and returns the OpenGL conversion matrix.
+
+ Returns:
+ numpy.ndarray: A 4x4 OpenGL conversion matrix.
+ """
+ # Create an identity matrix
+ matrix = np.identity(4)
+
+ # Flip the y and z axes
+ matrix[1, 1] = -1
+ matrix[2, 2] = -1
+
+ return matrix
+
+
+def transform_points(transformation: np.ndarray, points: np.ndarray, dim: int = None) -> np.ndarray:
+ """
+ Applies a 4x4 transformation to a set of points.
+
+ Args:
+ transformation (np.ndarray): Transformation matrix.
+ points (np.ndarray): Points to be transformed.
+ dim (int, optional): Dimension for reshaping the result.
+
+ Returns:
+ np.ndarray: Transformed points.
+ """
+ points = np.asarray(points)
+ initial_shape = points.shape[:-1]
+ dim = dim or points.shape[-1]
+
+ # Apply transformation
+ transformation = transformation.swapaxes(-1, -2) # Transpose the transformation matrix
+ points = points @ transformation[..., :-1, :] + transformation[..., -1:, :]
+
+ # Reshape the result
+ result = points[..., :dim].reshape(*initial_shape, dim)
+ return result
+
+
+def compute_camera_faces(cone_shape: trimesh.Trimesh) -> np.ndarray:
+ """
+ Computes the faces for the camera mesh.
+
+ Args:
+ cone_shape (trimesh.Trimesh): The shape of the camera cone.
+
+ Returns:
+ np.ndarray: Array of faces for the camera mesh.
+ """
+ # Create pseudo cameras
+ faces_list = []
+ num_vertices_cone = len(cone_shape.vertices)
+
+ for face in cone_shape.faces:
+ if 0 in face:
+ continue
+ v1, v2, v3 = face
+ v1_offset, v2_offset, v3_offset = face + num_vertices_cone
+ v1_offset_2, v2_offset_2, v3_offset_2 = face + 2 * num_vertices_cone
+
+ faces_list.extend(
+ [
+ (v1, v2, v2_offset),
+ (v1, v1_offset, v3),
+ (v3_offset, v2, v3),
+ (v1, v2, v2_offset_2),
+ (v1, v1_offset_2, v3),
+ (v3_offset_2, v2, v3),
+ ]
+ )
+
+ faces_list += [(v3, v2, v1) for v1, v2, v3 in faces_list]
+ return np.array(faces_list)
+
+
+def segment_sky(image_path, onnx_session, mask_filename=None):
+ """
+ Segments sky from an image using an ONNX model.
+ Thanks for the great model provided by https://github.com/xiongzhu666/Sky-Segmentation-and-Post-processing
+
+ Args:
+ image_path: Path to input image
+ onnx_session: ONNX runtime session with loaded model
+ mask_filename: Path to save the output mask
+
+ Returns:
+ np.ndarray: Binary mask where 255 indicates non-sky regions
+ """
+
+ assert mask_filename is not None
+ image = cv2.imread(image_path)
+
+ result_map = run_skyseg(onnx_session, [320, 320], image)
+ # resize the result_map to the original image size
+ result_map_original = cv2.resize(result_map, (image.shape[1], image.shape[0]))
+
+ # Fix: Invert the mask so that 255 = non-sky, 0 = sky
+ # The model outputs low values for sky, high values for non-sky
+ output_mask = np.zeros_like(result_map_original)
+ output_mask[result_map_original < 32] = 255 # Use threshold of 32
+
+ os.makedirs(os.path.dirname(mask_filename), exist_ok=True)
+ cv2.imwrite(mask_filename, output_mask)
+ return output_mask
+
+
+def run_skyseg(onnx_session, input_size, image):
+ """
+ Runs sky segmentation inference using ONNX model.
+
+ Args:
+ onnx_session: ONNX runtime session
+ input_size: Target size for model input (width, height)
+ image: Input image in BGR format
+
+ Returns:
+ np.ndarray: Segmentation mask
+ """
+
+ # Pre process:Resize, BGR->RGB, Transpose, PyTorch standardization, float32 cast
+ temp_image = copy.deepcopy(image)
+ resize_image = cv2.resize(temp_image, dsize=(input_size[0], input_size[1]))
+ x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB)
+ x = np.array(x, dtype=np.float32)
+ mean = [0.485, 0.456, 0.406]
+ std = [0.229, 0.224, 0.225]
+ x = (x / 255 - mean) / std
+ x = x.transpose(2, 0, 1)
+ x = x.reshape(-1, 3, input_size[0], input_size[1]).astype("float32")
+
+ # Inference
+ input_name = onnx_session.get_inputs()[0].name
+ output_name = onnx_session.get_outputs()[0].name
+ onnx_result = onnx_session.run([output_name], {input_name: x})
+
+ # Post process
+ onnx_result = np.array(onnx_result).squeeze()
+ min_value = np.min(onnx_result)
+ max_value = np.max(onnx_result)
+ onnx_result = (onnx_result - min_value) / (max_value - min_value)
+ onnx_result *= 255
+ onnx_result = onnx_result.astype("uint8")
+
+ return onnx_result
+
+
+def download_file_from_url(url, filename):
+ """Downloads a file from a Hugging Face model repo, handling redirects."""
+ try:
+ # Get the redirect URL
+ response = requests.get(url, allow_redirects=False)
+ response.raise_for_status() # Raise HTTPError for bad requests (4xx or 5xx)
+
+ if response.status_code == 302: # Expecting a redirect
+ redirect_url = response.headers["Location"]
+ response = requests.get(redirect_url, stream=True)
+ response.raise_for_status()
+ else:
+ print(f"Unexpected status code: {response.status_code}")
+ return
+
+ with open(filename, "wb") as f:
+ for chunk in response.iter_content(chunk_size=8192):
+ f.write(chunk)
+ print(f"Downloaded {filename} successfully.")
+
+ except requests.exceptions.RequestException as e:
+ print(f"Error downloading file: {e}")