update to the published ver
- app.py +2 -2
- streamvggt/heads/camera_head.py +175 -0
- streamvggt/heads/dpt_head.py +472 -0
- vggt/heads/__pycache__/head_act.cpython-310.pyc → streamvggt/heads/head_act.py +0 -0
- vggt/heads/__pycache__/track_head.cpython-310.pyc → streamvggt/heads/track_head.py +0 -0
- streamvggt/heads/track_modules/__init__.py +0 -0
- streamvggt/heads/track_modules/__pycache__/__init__.cpython-310.pyc +0 -0
- streamvggt/heads/track_modules/__pycache__/__init__.cpython-311.pyc +0 -0
- streamvggt/heads/track_modules/__pycache__/__init__.cpython-312.pyc +0 -0
- streamvggt/heads/track_modules/__pycache__/base_track_predictor.cpython-310.pyc +0 -0
- streamvggt/heads/track_modules/__pycache__/base_track_predictor.cpython-311.pyc +0 -0
- streamvggt/heads/track_modules/__pycache__/base_track_predictor.cpython-312.pyc +0 -0
- streamvggt/heads/track_modules/__pycache__/blocks.cpython-310.pyc +0 -0
- streamvggt/heads/track_modules/__pycache__/blocks.cpython-311.pyc +0 -0
- streamvggt/heads/track_modules/__pycache__/blocks.cpython-312.pyc +0 -0
- streamvggt/heads/track_modules/__pycache__/modules.cpython-310.pyc +0 -0
- streamvggt/heads/track_modules/__pycache__/modules.cpython-311.pyc +0 -0
- streamvggt/heads/track_modules/__pycache__/modules.cpython-312.pyc +0 -0
- streamvggt/heads/track_modules/__pycache__/utils.cpython-310.pyc +0 -0
- streamvggt/heads/track_modules/__pycache__/utils.cpython-311.pyc +0 -0
- streamvggt/heads/track_modules/__pycache__/utils.cpython-312.pyc +0 -0
- streamvggt/heads/track_modules/base_track_predictor.py +195 -0
- streamvggt/heads/track_modules/blocks.py +237 -0
- streamvggt/heads/track_modules/modules.py +211 -0
- streamvggt/heads/track_modules/utils.py +216 -0
- vggt/heads/__pycache__/utils.cpython-310.pyc → streamvggt/heads/utils.py +0 -0
- streamvggt/layers/__init__.py +5 -0
- streamvggt/layers/attention.py +129 -0
- streamvggt/layers/block.py +263 -0
- streamvggt/layers/drop_path.py +24 -0
- streamvggt/layers/layer_scale.py +20 -0
- streamvggt/layers/mlp.py +30 -0
- streamvggt/layers/patch_embed.py +79 -0
- vggt/layers/__pycache__/rope.cpython-310.pyc → streamvggt/layers/rope.py +0 -0
- streamvggt/layers/swiglu_ffn.py +67 -0
- streamvggt/layers/vision_transformer.py +398 -0
- streamvggt/models/aggregator.py +394 -0
- streamvggt/models/streamvggt.py +172 -0
- streamvggt/utils/geometry.py +166 -0
- streamvggt/utils/load_fn.py +146 -0
- vggt/utils/__pycache__/pose_enc.cpython-310.pyc → streamvggt/utils/pose_enc.py +0 -0
- streamvggt/utils/rotation.py +138 -0
- streamvggt/utils/visual_track.py +239 -0
- vggt/heads/__pycache__/camera_head.cpython-310.pyc +0 -0
- vggt/heads/__pycache__/camera_head.cpython-311.pyc +0 -0
- vggt/heads/__pycache__/camera_head.cpython-312.pyc +0 -0
- vggt/heads/__pycache__/dpt_head.cpython-310.pyc +0 -0
- vggt/heads/__pycache__/dpt_head.cpython-311.pyc +0 -0
- vggt/heads/__pycache__/dpt_head.cpython-312.pyc +0 -0
- vggt/heads/__pycache__/head_act.cpython-311.pyc +0 -0
app.py
CHANGED
@@ -29,7 +29,7 @@ import gc
 import time

 from visual_util import predictions_to_glb
-from
+from streamvggt.models.streamvggt import StreamVGGT
 from vggt.utils.load_fn import load_and_preprocess_images
 from vggt.utils.pose_enc import pose_encoding_to_extri_intri
 from vggt.utils.geometry import unproject_depth_map_to_point_map
@@ -45,7 +45,7 @@ path = hf_hub_download(
     revision="main",
     force_download=True
 )
-model =
+model = StreamVGGT()
 ckpt = torch.load(path, map_location=device)
 model.load_state_dict(ckpt, strict=True)
 model = model.to(device)
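Note: taken together with the surrounding context lines, the load path in app.py after this commit reads roughly as the sketch below. The StreamVGGT import, model = StreamVGGT(), torch.load, load_state_dict, and .to(device) lines come from the hunks above; the repo_id and filename arguments fall outside the hunks and are shown here as placeholders only.

# Sketch of the patched model-loading path in app.py.
# repo_id/filename are placeholders, not shown in the hunks above.
import torch
from huggingface_hub import hf_hub_download
from streamvggt.models.streamvggt import StreamVGGT

device = "cuda" if torch.cuda.is_available() else "cpu"

path = hf_hub_download(
    repo_id="<streamvggt-checkpoint-repo>",   # placeholder
    filename="<checkpoint.pt>",               # placeholder
    revision="main",
    force_download=True,
)

model = StreamVGGT()
ckpt = torch.load(path, map_location=device)
model.load_state_dict(ckpt, strict=True)    # strict=True raises if any key is missing or unexpected
model = model.to(device)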
streamvggt/heads/camera_head.py
ADDED
@@ -0,0 +1,175 @@
import math
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from streamvggt.layers import Mlp
from streamvggt.layers.block import Block
from streamvggt.heads.head_act import activate_pose


class CameraHead(nn.Module):
    def __init__(
        self,
        dim_in: int = 2048,
        trunk_depth: int = 4,
        pose_encoding_type: str = "absT_quaR_FoV",
        num_heads: int = 16,
        mlp_ratio: int = 4,
        init_values: float = 0.01,
        trans_act: str = "linear",
        quat_act: str = "linear",
        fl_act: str = "relu",  # Field of view activations: ensures FOV values are positive.
    ):
        super().__init__()

        if pose_encoding_type == "absT_quaR_FoV":
            self.target_dim = 9
        else:
            raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")

        self.trans_act = trans_act
        self.quat_act = quat_act
        self.fl_act = fl_act
        self.trunk_depth = trunk_depth

        # Build the trunk using a sequence of transformer blocks.
        self.trunk = nn.Sequential(
            *[
                Block(
                    dim=dim_in,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    init_values=init_values,
                )
                for _ in range(trunk_depth)
            ]
        )

        # Normalizations for camera token and trunk output.
        self.token_norm = nn.LayerNorm(dim_in)
        self.trunk_norm = nn.LayerNorm(dim_in)

        # Learnable empty camera pose token.
        self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
        self.embed_pose = nn.Linear(self.target_dim, dim_in)

        # Module for producing modulation parameters: shift, scale, and a gate.
        self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))

        # Adaptive layer normalization without affine parameters.
        self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
        self.pose_branch = Mlp(
            in_features=dim_in,
            hidden_features=dim_in // 2,
            out_features=self.target_dim,
            drop=0,
        )

    def forward(self, aggregated_tokens_list: list, num_iterations: int = 4, past_key_values_camera=None, use_cache: bool = False) -> list:
        """
        Forward pass to predict camera parameters.

        Args:
            aggregated_tokens_list (list): List of token tensors from the network;
                the last tensor is used for prediction.
            num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.

        Returns:
            list: A list of predicted camera encodings (post-activation) from each iteration.
        """
        # Use tokens from the last block for camera prediction.
        tokens = aggregated_tokens_list[-1]

        # Extract the camera tokens
        pose_tokens = tokens[:, :, 0]
        pose_tokens = self.token_norm(pose_tokens)

        if use_cache:
            pred_pose_enc_list, past_key_values_camera = self.trunk_fn(pose_tokens, num_iterations, past_key_values_camera, use_cache)
            return pred_pose_enc_list, past_key_values_camera
        else:
            pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations, past_key_values_camera=None, use_cache=use_cache)
            return pred_pose_enc_list

    def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int, past_key_values_camera, use_cache: bool) -> list:
        """
        Iteratively refine camera pose predictions.

        Args:
            pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
            num_iterations (int): Number of refinement iterations.

        Returns:
            list: List of activated camera encodings from each iteration.
        """
        B, S, C = pose_tokens.shape  # S is expected to be 1.
        pred_pose_enc = None
        pred_pose_enc_list = []

        for _ in range(num_iterations):
            # Use a learned empty pose for the first iteration.
            if pred_pose_enc is None:
                module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
            else:
                # Detach the previous prediction to avoid backprop through time.
                pred_pose_enc = pred_pose_enc.detach()
                module_input = self.embed_pose(pred_pose_enc)

            # Generate modulation parameters and split them into shift, scale, and gate components.
            shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)

            # Adaptive layer normalization and modulation.
            pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
            pose_tokens_modulated = pose_tokens_modulated + pose_tokens

            if not use_cache:
                L = S * 1
                frame_ids = torch.arange(L, device=pose_tokens_modulated.device) // 1  # [0,0,...,1,1,...,S-1]
                future_frame = frame_ids.unsqueeze(1) < frame_ids.unsqueeze(0)
                attn_mask = future_frame.to(pose_tokens_modulated.dtype) * torch.finfo(pose_tokens_modulated.dtype).min
            else:
                attn_mask = None

            if use_cache:
                for idx in range(self.trunk_depth):
                    pose_tokens_modulated, block_kv = self.trunk[idx](
                        pose_tokens_modulated,
                        attn_mask=attn_mask,
                        past_key_values=past_key_values_camera[idx] if past_key_values_camera[idx] is not None else None,
                        use_cache=True
                    )
                    past_key_values_camera[idx] = block_kv
            else:
                for idx in range(self.trunk_depth):
                    pose_tokens_modulated = self.trunk[idx](pose_tokens_modulated, attn_mask=attn_mask)

            # Compute the delta update for the pose encoding.
            pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))

            if pred_pose_enc is None:
                pred_pose_enc = pred_pose_enc_delta
            else:
                pred_pose_enc = pred_pose_enc + pred_pose_enc_delta

            # Apply final activation functions for translation, quaternion, and field-of-view.
            activated_pose = activate_pose(
                pred_pose_enc,
                trans_act=self.trans_act,
                quat_act=self.quat_act,
                fl_act=self.fl_act,
            )
            pred_pose_enc_list.append(activated_pose)

        if use_cache:
            return pred_pose_enc_list, past_key_values_camera
        return pred_pose_enc_list


def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """
    Modulate the input tensor using scaling and shifting parameters.
    """
    return x * (1 + scale) + shift
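Note: the two small mechanisms that drive the refinement loop above are the shift/scale/gate modulation produced by poseLN_modulation and the additive causal mask built from frame ids. A standalone sketch in plain PyTorch, with made-up dimensions, of what those few lines compute:

# Standalone sketch of the adaLN-style modulation and the additive causal
# mask used by CameraHead above. Dimensions are made up for illustration.
import torch
import torch.nn as nn

def modulate(x, shift, scale):
    # Same formula as camera_head.modulate: scale and shift a normalized input.
    return x * (1 + scale) + shift

B, S, C = 2, 1, 8                        # hypothetical batch, frames, channels
pose_tokens = torch.randn(B, S, C)
module_input = torch.randn(B, S, C)      # stands in for embed_pose(...)

poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(C, 3 * C, bias=True))
adaln_norm = nn.LayerNorm(C, elementwise_affine=False, eps=1e-6)

shift, scale, gate = poseLN_modulation(module_input).chunk(3, dim=-1)
modulated = gate * modulate(adaln_norm(pose_tokens), shift, scale) + pose_tokens
print(modulated.shape)                   # torch.Size([2, 1, 8])

# Additive causal mask: future frames get a large negative value so softmax
# attention ignores them (mirrors the `not use_cache` branch above).
frame_ids = torch.arange(S)
future = frame_ids.unsqueeze(1) < frame_ids.unsqueeze(0)
attn_mask = future.to(pose_tokens.dtype) * torch.finfo(pose_tokens.dtype).min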
streamvggt/heads/dpt_head.py
ADDED
@@ -0,0 +1,472 @@
import os
from typing import List, Dict, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from .head_act import activate_head
from .utils import create_uv_grid, position_grid_to_embed


class DPTHead(nn.Module):
    """
    Args:
        dim_in (int): Input dimension (channels).
        patch_size (int, optional): Patch size. Default is 14.
        output_dim (int, optional): Number of output channels. Default is 4.
        activation (str, optional): Activation type. Default is "inv_log".
        conf_activation (str, optional): Confidence activation type. Default is "expp1".
        features (int, optional): Feature channels for intermediate representations. Default is 256.
        out_channels (List[int], optional): Output channels for each intermediate layer.
        intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
        pos_embed (bool, optional): Whether to use positional embedding. Default is True.
        feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
        down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
    """

    def __init__(
        self,
        dim_in: int,
        patch_size: int = 14,
        output_dim: int = 4,
        activation: str = "inv_log",
        conf_activation: str = "expp1",
        features: int = 256,
        out_channels: List[int] = [256, 512, 1024, 1024],
        intermediate_layer_idx: List[int] = [4, 11, 17, 23],
        pos_embed: bool = True,
        feature_only: bool = False,
        down_ratio: int = 1,
    ) -> None:
        super(DPTHead, self).__init__()
        self.patch_size = patch_size
        self.activation = activation
        self.conf_activation = conf_activation
        self.pos_embed = pos_embed
        self.feature_only = feature_only
        self.down_ratio = down_ratio
        self.intermediate_layer_idx = intermediate_layer_idx

        self.norm = nn.LayerNorm(dim_in)

        # Projection layers for each output channel from tokens.
        self.projects = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels=dim_in,
                    out_channels=oc,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
                for oc in out_channels
            ]
        )

        # Resize layers for upsampling feature maps.
        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(
                    in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
                ),
                nn.ConvTranspose2d(
                    in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
                ),
                nn.Identity(),
                nn.Conv2d(
                    in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
                ),
            ]
        )

        self.scratch = _make_scratch(
            out_channels,
            features,
            expand=False,
        )

        # Attach additional modules to scratch.
        self.scratch.stem_transpose = None
        self.scratch.refinenet1 = _make_fusion_block(features)
        self.scratch.refinenet2 = _make_fusion_block(features)
        self.scratch.refinenet3 = _make_fusion_block(features)
        self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)

        head_features_1 = features
        head_features_2 = 32

        if feature_only:
            self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
        else:
            self.scratch.output_conv1 = nn.Conv2d(
                head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
            )
            conv2_in_channels = head_features_1 // 2

            self.scratch.output_conv2 = nn.Sequential(
                nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
            )

    def forward(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        frames_chunk_size: int = 8,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass through the DPT head, supports processing by chunking frames.
        Args:
            aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
            images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
            patch_start_idx (int): Starting index for patch tokens in the token sequence.
                Used to separate patch tokens from other tokens (e.g., camera or register tokens).
            frames_chunk_size (int, optional): Number of frames to process in each chunk.
                If None or larger than S, all frames are processed at once. Default: 8.

        Returns:
            Tensor or Tuple[Tensor, Tensor]:
                - If feature_only=True: Feature maps with shape [B, S, C, H, W]
                - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
        """
        B, S, _, H, W = images.shape

        # If frames_chunk_size is not specified or greater than S, process all frames at once
        if frames_chunk_size is None or frames_chunk_size >= S:
            return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)

        # Otherwise, process frames in chunks to manage memory usage
        assert frames_chunk_size > 0

        # Process frames in batches
        all_preds = []
        all_conf = []

        for frames_start_idx in range(0, S, frames_chunk_size):
            frames_end_idx = min(frames_start_idx + frames_chunk_size, S)

            # Process batch of frames
            if self.feature_only:
                chunk_output = self._forward_impl(
                    aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
                )
                all_preds.append(chunk_output)
            else:
                chunk_preds, chunk_conf = self._forward_impl(
                    aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
                )
                all_preds.append(chunk_preds)
                all_conf.append(chunk_conf)

        # Concatenate results along the sequence dimension
        if self.feature_only:
            return torch.cat(all_preds, dim=1)
        else:
            return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)

    def _forward_impl(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        frames_start_idx: int = None,
        frames_end_idx: int = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
            aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
            images (Tensor): Input images with shape [B, S, 3, H, W].
            patch_start_idx (int): Starting index for patch tokens.
            frames_start_idx (int, optional): Starting index for frames to process.
            frames_end_idx (int, optional): Ending index for frames to process.

        Returns:
            Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
        """
        if frames_start_idx is not None and frames_end_idx is not None:
            images = images[:, frames_start_idx:frames_end_idx].contiguous()

        B, S, _, H, W = images.shape

        patch_h, patch_w = H // self.patch_size, W // self.patch_size

        out = []
        dpt_idx = 0

        for layer_idx in self.intermediate_layer_idx:
            x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]

            # Select frames if processing a chunk
            if frames_start_idx is not None and frames_end_idx is not None:
                x = x[:, frames_start_idx:frames_end_idx]

            x = x.reshape(B * S, -1, x.shape[-1])

            x = self.norm(x)

            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))

            x = self.projects[dpt_idx](x)
            if self.pos_embed:
                x = self._apply_pos_embed(x, W, H)
            x = self.resize_layers[dpt_idx](x)

            out.append(x)
            dpt_idx += 1

        # Fuse features from multiple layers.
        out = self.scratch_forward(out)
        # Interpolate fused output to match target image resolution.
        out = custom_interpolate(
            out,
            (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
            mode="bilinear",
            align_corners=True,
        )

        if self.pos_embed:
            out = self._apply_pos_embed(out, W, H)

        if self.feature_only:
            return out.reshape(B, S, *out.shape[1:])

        out = self.scratch.output_conv2(out)
        preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)

        preds = preds.reshape(B, S, *preds.shape[1:])
        conf = conf.reshape(B, S, *conf.shape[1:])
        return preds, conf

    def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
        """
        Apply positional embedding to tensor x.
        """
        patch_w = x.shape[-1]
        patch_h = x.shape[-2]
        pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
        pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
        pos_embed = pos_embed * ratio
        pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
        return x + pos_embed

    def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        """
        Forward pass through the fusion blocks.

        Args:
            features (List[Tensor]): List of feature maps from different layers.

        Returns:
            Tensor: Fused feature map.
        """
        layer_1, layer_2, layer_3, layer_4 = features

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        del layer_4_rn, layer_4

        out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
        del layer_3_rn, layer_3

        out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
        del layer_2_rn, layer_2

        out = self.scratch.refinenet1(out, layer_1_rn)
        del layer_1_rn, layer_1

        out = self.scratch.output_conv1(out)
        return out


def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
    return FeatureFusionBlock(
        features,
        nn.ReLU(inplace=True),
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=size,
        has_residual=has_residual,
        groups=groups,
    )


def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
    scratch = nn.Module()
    out_shape1 = out_shape
    out_shape2 = out_shape
    out_shape3 = out_shape
    if len(in_shape) >= 4:
        out_shape4 = out_shape

    if expand:
        out_shape1 = out_shape
        out_shape2 = out_shape * 2
        out_shape3 = out_shape * 4
        if len(in_shape) >= 4:
            out_shape4 = out_shape * 8

    scratch.layer1_rn = nn.Conv2d(
        in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    scratch.layer2_rn = nn.Conv2d(
        in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    scratch.layer3_rn = nn.Conv2d(
        in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    if len(in_shape) >= 4:
        scratch.layer4_rn = nn.Conv2d(
            in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
        )
    return scratch


class ResidualConvUnit(nn.Module):
    """Residual convolution module."""

    def __init__(self, features, activation, bn, groups=1):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.bn = bn
        self.groups = groups
        self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
        self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)

        self.norm1 = None
        self.norm2 = None

        self.activation = activation
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """

        out = self.activation(x)
        out = self.conv1(out)
        if self.norm1 is not None:
            out = self.norm1(out)

        out = self.activation(out)
        out = self.conv2(out)
        if self.norm2 is not None:
            out = self.norm2(out)

        return self.skip_add.add(out, x)


class FeatureFusionBlock(nn.Module):
    """Feature fusion block."""

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=None,
        has_residual=True,
        groups=1,
    ):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners
        self.groups = groups
        self.expand = expand
        out_features = features
        if self.expand == True:
            out_features = features // 2

        self.out_conv = nn.Conv2d(
            features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
        )

        if has_residual:
            self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)

        self.has_residual = has_residual
        self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)

        self.skip_add = nn.quantized.FloatFunctional()
        self.size = size

    def forward(self, *xs, size=None):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        if self.has_residual:
            res = self.resConfUnit1(xs[1])
            output = self.skip_add.add(output, res)

        output = self.resConfUnit2(output)

        if (size is None) and (self.size is None):
            modifier = {"scale_factor": 2}
        elif size is None:
            modifier = {"size": self.size}
        else:
            modifier = {"size": size}

        output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
        output = self.out_conv(output)

        return output


def custom_interpolate(
    x: torch.Tensor,
    size: Tuple[int, int] = None,
    scale_factor: float = None,
    mode: str = "bilinear",
    align_corners: bool = True,
) -> torch.Tensor:
    """
    Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
    """
    if size is None:
        size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))

    INT_MAX = 1610612736

    input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]

    if input_elements > INT_MAX:
        chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
        interpolated_chunks = [
            nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks
        ]
        x = torch.cat(interpolated_chunks, dim=0)
        return x.contiguous()
    else:
        return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
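Note: the chunked forward above amounts to running the head on a slice of frames and concatenating the per-chunk outputs along the sequence dimension. A toy sketch of that pattern, where run_head_on_chunk is a hypothetical stand-in for self._forward_impl:

# Toy illustration of the frame-chunking pattern in DPTHead.forward.
import torch

def run_head_on_chunk(frames):                 # stand-in for _forward_impl(...)
    B, S, C, H, W = frames.shape
    return frames.mean(dim=2, keepdim=True)    # pretend prediction [B, S, 1, H, W]

images = torch.randn(1, 20, 3, 28, 28)          # B=1, S=20 frames
frames_chunk_size = 8

all_preds = []
S = images.shape[1]
for start in range(0, S, frames_chunk_size):
    end = min(start + frames_chunk_size, S)
    all_preds.append(run_head_on_chunk(images[:, start:end]))

preds = torch.cat(all_preds, dim=1)             # concatenate along the sequence dim
assert preds.shape[1] == S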
vggt/heads/__pycache__/head_act.cpython-310.pyc → streamvggt/heads/head_act.py
RENAMED
Binary files a/vggt/heads/__pycache__/head_act.cpython-310.pyc and b/streamvggt/heads/head_act.py differ

vggt/heads/__pycache__/track_head.cpython-310.pyc → streamvggt/heads/track_head.py
RENAMED
Binary files a/vggt/heads/__pycache__/track_head.cpython-310.pyc and b/streamvggt/heads/track_head.py differ

streamvggt/heads/track_modules/__init__.py
ADDED
File without changes

streamvggt/heads/track_modules/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (157 Bytes)

streamvggt/heads/track_modules/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (175 Bytes)

streamvggt/heads/track_modules/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (165 Bytes)

streamvggt/heads/track_modules/__pycache__/base_track_predictor.cpython-310.pyc
ADDED
Binary file (4.27 kB)

streamvggt/heads/track_modules/__pycache__/base_track_predictor.cpython-311.pyc
ADDED
Binary file (9.38 kB)

streamvggt/heads/track_modules/__pycache__/base_track_predictor.cpython-312.pyc
ADDED
Binary file (8.77 kB)

streamvggt/heads/track_modules/__pycache__/blocks.cpython-310.pyc
ADDED
Binary file (6.58 kB)

streamvggt/heads/track_modules/__pycache__/blocks.cpython-311.pyc
ADDED
Binary file (12.9 kB)

streamvggt/heads/track_modules/__pycache__/blocks.cpython-312.pyc
ADDED
Binary file (11.6 kB)

streamvggt/heads/track_modules/__pycache__/modules.cpython-310.pyc
ADDED
Binary file (5.27 kB)

streamvggt/heads/track_modules/__pycache__/modules.cpython-311.pyc
ADDED
Binary file (10 kB)

streamvggt/heads/track_modules/__pycache__/modules.cpython-312.pyc
ADDED
Binary file (8.79 kB)

streamvggt/heads/track_modules/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (7.36 kB)

streamvggt/heads/track_modules/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (11.1 kB)

streamvggt/heads/track_modules/__pycache__/utils.cpython-312.pyc
ADDED
Binary file (10.4 kB)
streamvggt/heads/track_modules/base_track_predictor.py
ADDED
@@ -0,0 +1,195 @@
import torch
import torch.nn as nn
from einops import rearrange, repeat


from .blocks import EfficientUpdateFormer, CorrBlock
from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed
from .modules import Mlp


class BaseTrackerPredictor(nn.Module):
    def __init__(
        self,
        stride=1,
        corr_levels=5,
        corr_radius=4,
        latent_dim=128,
        hidden_size=384,
        use_spaceatt=True,
        depth=6,
        max_scale=518,
        predict_conf=True,
    ):
        super(BaseTrackerPredictor, self).__init__()
        self.stride = stride
        self.latent_dim = latent_dim
        self.corr_levels = corr_levels
        self.corr_radius = corr_radius
        self.hidden_size = hidden_size
        self.max_scale = max_scale
        self.predict_conf = predict_conf

        self.flows_emb_dim = latent_dim // 2

        self.corr_mlp = Mlp(
            in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2,
            hidden_features=self.hidden_size,
            out_features=self.latent_dim,
        )

        self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4

        self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim))

        space_depth = depth if use_spaceatt else 0
        time_depth = depth

        self.updateformer = EfficientUpdateFormer(
            space_depth=space_depth,
            time_depth=time_depth,
            input_dim=self.transformer_dim,
            hidden_size=self.hidden_size,
            output_dim=self.latent_dim + 2,
            mlp_ratio=4.0,
            add_space_attn=use_spaceatt,
        )

        self.fmap_norm = nn.LayerNorm(self.latent_dim)
        self.ffeat_norm = nn.GroupNorm(1, self.latent_dim)

        # A linear layer to update track feats at each iteration
        self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU())

        self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))

        if predict_conf:
            self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))

    def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True):
        """
        query_points: B x N x 2, the number of batches, tracks, and xy
        fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension.
            note HH and WW is the size of feature maps instead of original images
        """
        B, N, D = query_points.shape
        B, S, C, HH, WW = fmaps.shape

        assert D == 2, "Input points must be 2D coordinates"

        # apply a layernorm to fmaps here
        fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2))
        fmaps = fmaps.permute(0, 1, 4, 2, 3)

        # Scale the input query_points because we may downsample the images
        # by down_ratio or self.stride
        # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map
        # its query_points should be query_points/4
        if down_ratio > 1:
            query_points = query_points / float(down_ratio)

        query_points = query_points / float(self.stride)

        # Init with coords as the query points
        # It means the search will start from the position of query points at the reference frames
        coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)

        # Sample/extract the features of the query points in the query frame
        query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])

        # init track feats by query feats
        track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1)  # B, S, N, C
        # back up the init coords
        coords_backup = coords.clone()

        fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius)

        coord_preds = []

        # Iterative Refinement
        for _ in range(iters):
            # Detach the gradients from the last iteration
            # (in my experience, not very important for performance)
            coords = coords.detach()

            fcorrs = fcorr_fn.corr_sample(track_feats, coords)

            corr_dim = fcorrs.shape[3]
            fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim)
            fcorrs_ = self.corr_mlp(fcorrs_)

            # Movement of current coords relative to query points
            flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)

            flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)

            # (In my trials, it is also okay to just add the flows_emb instead of concat)
            flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1)

            track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)

            # Concatenate them as the input for the transformers
            transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)

            # 2D positional embed
            pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device)
            sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0])

            sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1)

            x = transformer_input + sampled_pos_emb

            # Add the query ref token to the track feats
            query_ref_token = torch.cat(
                [self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1
            )
            x = x + query_ref_token.to(x.device).to(x.dtype)

            # B, N, S, C
            x = rearrange(x, "(b n) s d -> b n s d", b=B)

            # Compute the delta coordinates and delta track features
            delta, _ = self.updateformer(x)

            # BN, S, C
            delta = rearrange(delta, " b n s d -> (b n) s d", b=B)
            delta_coords_ = delta[:, :, :2]
            delta_feats_ = delta[:, :, 2:]

            track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
            delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)

            # Update the track features
            track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_

            track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3)  # BxSxNxC

            # B x S x N x 2
            coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)

            # Force coord0 as query
            # because we assume the query points should not be changed
            coords[:, 0] = coords_backup[:, 0]

            # The predicted tracks are in the original image scale
            if down_ratio > 1:
                coord_preds.append(coords * self.stride * down_ratio)
            else:
                coord_preds.append(coords * self.stride)

        # B, S, N
        vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
        if apply_sigmoid:
            vis_e = torch.sigmoid(vis_e)

        if self.predict_conf:
            conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
            if apply_sigmoid:
                conf_e = torch.sigmoid(conf_e)
        else:
            conf_e = None

        if return_feat:
            return coord_preds, vis_e, track_feats, query_track_feat, conf_e
        else:
            return coord_preds, vis_e, conf_e
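Note: the tracker's coordinate bookkeeping is a round trip: query points are scaled down into feature-map space by down_ratio and stride before matching, and the refined coordinates are scaled back to image pixels before being returned. A minimal sketch with example values (not taken from any model config):

# Sketch of the coordinate scaling used by BaseTrackerPredictor.
import torch

stride, down_ratio = 1, 4                        # example values; stride defaults to 1 above
query_points = torch.tensor([[[512.0, 256.0]]])  # B x N x 2, in image pixels

coords = query_points / float(down_ratio) / float(stride)   # feature-map space
# ... iterative refinement would update `coords` here ...
pred_tracks = coords * stride * down_ratio                   # back to image pixels

assert torch.allclose(pred_tracks, query_points)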
streamvggt/heads/track_modules/blocks.py
ADDED
@@ -0,0 +1,237 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import bilinear_sampler
from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock


class EfficientUpdateFormer(nn.Module):
    """
    Transformer model that updates track estimates.
    """

    def __init__(
        self,
        space_depth=6,
        time_depth=6,
        input_dim=320,
        hidden_size=384,
        num_heads=8,
        output_dim=130,
        mlp_ratio=4.0,
        add_space_attn=True,
        num_virtual_tracks=64,
    ):
        super().__init__()

        self.out_channels = 2
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.add_space_attn = add_space_attn

        # Add input LayerNorm before linear projection
        self.input_norm = nn.LayerNorm(input_dim)
        self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)

        # Add output LayerNorm before final projection
        self.output_norm = nn.LayerNorm(hidden_size)
        self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
        self.num_virtual_tracks = num_virtual_tracks

        if self.add_space_attn:
            self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))
        else:
            self.virual_tracks = None

        self.time_blocks = nn.ModuleList(
            [
                AttnBlock(
                    hidden_size,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    attn_class=nn.MultiheadAttention,
                )
                for _ in range(time_depth)
            ]
        )

        if add_space_attn:
            self.space_virtual_blocks = nn.ModuleList(
                [
                    AttnBlock(
                        hidden_size,
                        num_heads,
                        mlp_ratio=mlp_ratio,
                        attn_class=nn.MultiheadAttention,
                    )
                    for _ in range(space_depth)
                ]
            )
            self.space_point2virtual_blocks = nn.ModuleList(
                [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
            )
            self.space_virtual2point_blocks = nn.ModuleList(
                [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
            )
            assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
        self.initialize_weights()

    def initialize_weights(self):
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)

        self.apply(_basic_init)

    def forward(self, input_tensor, mask=None):
        # Apply input LayerNorm
        input_tensor = self.input_norm(input_tensor)
        tokens = self.input_transform(input_tensor)

        init_tokens = tokens

        B, _, T, _ = tokens.shape

        if self.add_space_attn:
            virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
            tokens = torch.cat([tokens, virtual_tokens], dim=1)

        _, N, _, _ = tokens.shape

        j = 0
        for i in range(len(self.time_blocks)):
            time_tokens = tokens.contiguous().view(B * N, T, -1)  # B N T C -> (B N) T C

            time_tokens = self.time_blocks[i](time_tokens)

            tokens = time_tokens.view(B, N, T, -1)  # (B N) T C -> B N T C
            if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0):
                space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)  # B N T C -> (B T) N C
                point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
                virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]

                virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask)
                virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
                point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask)

                space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
                tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3)  # (B T) N C -> B N T C
                j += 1

        if self.add_space_attn:
            tokens = tokens[:, : N - self.num_virtual_tracks]

        tokens = tokens + init_tokens

        # Apply output LayerNorm before final projection
        tokens = self.output_norm(tokens)
        flow = self.flow_head(tokens)

        return flow, None


class CorrBlock:
    def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"):
        """
        Build a pyramid of feature maps from the input.

        fmaps: Tensor (B, S, C, H, W)
        num_levels: number of pyramid levels (each downsampled by factor 2)
        radius: search radius for sampling correlation
        multiple_track_feats: if True, split the target features per pyramid level
        padding_mode: passed to grid_sample / bilinear_sampler
        """
        B, S, C, H, W = fmaps.shape
        self.S, self.C, self.H, self.W = S, C, H, W
        self.num_levels = num_levels
        self.radius = radius
        self.padding_mode = padding_mode
        self.multiple_track_feats = multiple_track_feats

        # Build pyramid: each level is half the spatial resolution of the previous
        self.fmaps_pyramid = [fmaps]  # level 0 is full resolution
        current_fmaps = fmaps
        for i in range(num_levels - 1):
            B, S, C, H, W = current_fmaps.shape
            # Merge batch & sequence dimensions
            current_fmaps = current_fmaps.reshape(B * S, C, H, W)
            # Avg pool down by factor 2
            current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2)
            _, _, H_new, W_new = current_fmaps.shape
            current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new)
            self.fmaps_pyramid.append(current_fmaps)

        # Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling.
        # This grid is added to the (scaled) coordinate centroids.
        r = self.radius
        dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
        dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
        # delta: for every (dy,dx) displacement (i.e. Δx, Δy)
        self.delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1)  # shape: (2r+1, 2r+1, 2)

    def corr_sample(self, targets, coords):
        """
        Instead of storing the entire correlation pyramid, we compute each level's correlation
        volume, sample it immediately, then discard it. This saves GPU memory.

        Args:
            targets: Tensor (B, S, N, C) — features for the current targets.
            coords: Tensor (B, S, N, 2) — coordinates at full resolution.

        Returns:
            Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations)
        """
        B, S, N, C = targets.shape

        # If you have multiple track features, split them per level.
        if self.multiple_track_feats:
            targets_split = torch.split(targets, C // self.num_levels, dim=-1)

        out_pyramid = []
        for i, fmaps in enumerate(self.fmaps_pyramid):
            # Get current spatial resolution H, W for this pyramid level.
            B, S, C, H, W = fmaps.shape
            # Reshape feature maps for correlation computation:
            # fmap2s: (B, S, C, H*W)
            fmap2s = fmaps.view(B, S, C, H * W)
            # Choose appropriate target features.
            fmap1 = targets_split[i] if self.multiple_track_feats else targets  # shape: (B, S, N, C)

            # Compute correlation directly
            corrs = compute_corr_level(fmap1, fmap2s, C)
            corrs = corrs.view(B, S, N, H, W)

            # Prepare sampling grid:
            # Scale down the coordinates for the current level.
            centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i)
            # Make sure our precomputed delta grid is on the same device/dtype.
            delta_lvl = self.delta.to(coords.device).to(coords.dtype)
            # Now the grid for grid_sample is:
            # coords_lvl = centroid_lvl + delta_lvl (broadcasted over grid)
            coords_lvl = centroid_lvl + delta_lvl.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2)

            # Sample from the correlation volume using bilinear interpolation.
            # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target.
            corrs_sampled = bilinear_sampler(
                corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode
            )
            # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims.
            corrs_sampled = corrs_sampled.view(B, S, N, -1)  # Now shape: (B, S, N, (2r+1)^2)
            out_pyramid.append(corrs_sampled)

        # Concatenate all levels along the last dimension.
        out = torch.cat(out_pyramid, dim=-1).contiguous()
        return out


def compute_corr_level(fmap1, fmap2s, C):
    # fmap1: (B, S, N, C)
    # fmap2s: (B, S, C, H*W)
    corrs = torch.matmul(fmap1, fmap2s)  # (B, S, N, H*W)
    corrs = corrs.view(fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1)  # (B, S, N, H*W)
    return corrs / math.sqrt(C)
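Note: compute_corr_level is a scaled dot product between each track feature and every spatial location of the feature map. A quick standalone sanity check of that equivalence against an explicit einsum, with arbitrary small shapes:

# Sanity-check sketch for compute_corr_level: the batched matmul against the
# flattened feature map equals a per-location dot product scaled by sqrt(C).
import math
import torch

B, S, N, C, H, W = 1, 2, 3, 16, 5, 7
fmap1 = torch.randn(B, S, N, C)          # track features
fmaps = torch.randn(B, S, C, H, W)       # feature maps
fmap2s = fmaps.view(B, S, C, H * W)

corrs = torch.matmul(fmap1, fmap2s) / math.sqrt(C)                          # (B, S, N, H*W)
reference = torch.einsum("bsnc,bschw->bsnhw", fmap1, fmaps).reshape(B, S, N, -1) / math.sqrt(C)

assert torch.allclose(corrs, reference, atol=1e-5)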
streamvggt/heads/track_modules/modules.py
ADDED
@@ -0,0 +1,211 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from typing import Callable
import collections
from torch import Tensor
from itertools import repeat


# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))

    return parse


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


to_2tuple = _ntuple(2)


class ResidualBlock(nn.Module):
    """
    ResidualBlock: construct a block of two conv layers with residual connections
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(
            in_planes,
            planes,
            kernel_size=kernel_size,
            padding=1,
            stride=stride,
            padding_mode="zeros",
        )
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=kernel_size,
            padding=1,
            padding_mode="zeros",
        )
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.BatchNorm2d(planes)

        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.InstanceNorm2d(planes)

        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if not stride == 1:
                self.norm3 = nn.Sequential()
        else:
            raise NotImplementedError

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
                self.norm3,
            )

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)


class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class AttnBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_heads,
        attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
        mlp_ratio=4.0,
        **block_kwargs
    ):
        """
        Self attention block
        """
        super().__init__()

        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)

        self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs)

        mlp_hidden_dim = int(hidden_size * mlp_ratio)

        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)

    def forward(self, x, mask=None):
        # Prepare the mask for PyTorch's attention (it expects a different format)
        # attn_mask = mask if mask is not None else None
        # Normalize before attention
        x = self.norm1(x)

        # PyTorch's MultiheadAttention returns attn_output, attn_output_weights
        # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask)

        attn_output, _ = self.attn(x, x, x)

        # Add & Norm
        x = x + attn_output
        x = x + self.mlp(self.norm2(x))
        return x


class CrossAttnBlock(nn.Module):
    def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
|
182 |
+
"""
|
183 |
+
Cross attention block
|
184 |
+
"""
|
185 |
+
super().__init__()
|
186 |
+
|
187 |
+
self.norm1 = nn.LayerNorm(hidden_size)
|
188 |
+
self.norm_context = nn.LayerNorm(hidden_size)
|
189 |
+
self.norm2 = nn.LayerNorm(hidden_size)
|
190 |
+
|
191 |
+
self.cross_attn = nn.MultiheadAttention(
|
192 |
+
embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
|
193 |
+
)
|
194 |
+
|
195 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
196 |
+
|
197 |
+
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
|
198 |
+
|
199 |
+
def forward(self, x, context, mask=None):
|
200 |
+
# Normalize inputs
|
201 |
+
x = self.norm1(x)
|
202 |
+
context = self.norm_context(context)
|
203 |
+
|
204 |
+
# Apply cross attention
|
205 |
+
# Note: nn.MultiheadAttention returns attn_output, attn_output_weights
|
206 |
+
attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask)
|
207 |
+
|
208 |
+
# Add & Norm
|
209 |
+
x = x + attn_output
|
210 |
+
x = x + self.mlp(self.norm2(x))
|
211 |
+
return x
|
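Illustrative note (not part of the diff): a minimal sketch of driving the AttnBlock and CrossAttnBlock defined in modules.py above, assuming the streamvggt package from this commit is on PYTHONPATH; all sizes are made up.

# Hedged usage sketch for the track-module blocks above.
import torch
from streamvggt.heads.track_modules.modules import AttnBlock, CrossAttnBlock

tokens = torch.randn(2, 64, 384)     # (batch, num_tokens, hidden_size)
context = torch.randn(2, 100, 384)   # (batch, num_context_tokens, hidden_size)

self_block = AttnBlock(hidden_size=384, num_heads=4)
cross_block = CrossAttnBlock(hidden_size=384, context_dim=384, num_heads=4)

tokens = self_block(tokens)              # pre-norm self-attention + MLP, shape preserved
tokens = cross_block(tokens, context)    # tokens attend to the context sequence
print(tokens.shape)  # torch.Size([2, 64, 384])
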
streamvggt/heads/track_modules/utils.py
ADDED
@@ -0,0 +1,216 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from typing import Optional, Tuple, Union
|
6 |
+
|
7 |
+
|
8 |
+
def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor:
|
9 |
+
"""
|
10 |
+
This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
|
11 |
+
It is a wrapper of get_2d_sincos_pos_embed_from_grid.
|
12 |
+
Args:
|
13 |
+
- embed_dim: The embedding dimension.
|
14 |
+
- grid_size: The grid size.
|
15 |
+
Returns:
|
16 |
+
- pos_embed: The generated 2D positional embedding.
|
17 |
+
"""
|
18 |
+
if isinstance(grid_size, tuple):
|
19 |
+
grid_size_h, grid_size_w = grid_size
|
20 |
+
else:
|
21 |
+
grid_size_h = grid_size_w = grid_size
|
22 |
+
grid_h = torch.arange(grid_size_h, dtype=torch.float)
|
23 |
+
grid_w = torch.arange(grid_size_w, dtype=torch.float)
|
24 |
+
grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
|
25 |
+
grid = torch.stack(grid, dim=0)
|
26 |
+
grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
|
27 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
28 |
+
if return_grid:
|
29 |
+
return (
|
30 |
+
pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2),
|
31 |
+
grid,
|
32 |
+
)
|
33 |
+
return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
|
34 |
+
|
35 |
+
|
36 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor:
|
37 |
+
"""
|
38 |
+
This function generates a 2D positional embedding from a given grid using sine and cosine functions.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
- embed_dim: The embedding dimension.
|
42 |
+
- grid: The grid to generate the embedding from.
|
43 |
+
|
44 |
+
Returns:
|
45 |
+
- emb: The generated 2D positional embedding.
|
46 |
+
"""
|
47 |
+
assert embed_dim % 2 == 0
|
48 |
+
|
49 |
+
# use half of dimensions to encode grid_h
|
50 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
51 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
52 |
+
|
53 |
+
emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
|
54 |
+
return emb
|
55 |
+
|
56 |
+
|
57 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:
|
58 |
+
"""
|
59 |
+
This function generates a 1D positional embedding from a given grid using sine and cosine functions.
|
60 |
+
|
61 |
+
Args:
|
62 |
+
- embed_dim: The embedding dimension.
|
63 |
+
- pos: The position to generate the embedding from.
|
64 |
+
|
65 |
+
Returns:
|
66 |
+
- emb: The generated 1D positional embedding.
|
67 |
+
"""
|
68 |
+
assert embed_dim % 2 == 0
|
69 |
+
omega = torch.arange(embed_dim // 2, dtype=torch.double)
|
70 |
+
omega /= embed_dim / 2.0
|
71 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
72 |
+
|
73 |
+
pos = pos.reshape(-1) # (M,)
|
74 |
+
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
75 |
+
|
76 |
+
emb_sin = torch.sin(out) # (M, D/2)
|
77 |
+
emb_cos = torch.cos(out) # (M, D/2)
|
78 |
+
|
79 |
+
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
|
80 |
+
return emb[None].float()
|
81 |
+
|
82 |
+
|
83 |
+
def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
|
84 |
+
"""
|
85 |
+
This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
|
86 |
+
|
87 |
+
Args:
|
88 |
+
- xy: The coordinates to generate the embedding from.
|
89 |
+
- C: The size of the embedding.
|
90 |
+
- cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
|
91 |
+
|
92 |
+
Returns:
|
93 |
+
- pe: The generated 2D positional embedding.
|
94 |
+
"""
|
95 |
+
B, N, D = xy.shape
|
96 |
+
assert D == 2
|
97 |
+
|
98 |
+
x = xy[:, :, 0:1]
|
99 |
+
y = xy[:, :, 1:2]
|
100 |
+
div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2))
|
101 |
+
|
102 |
+
pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
103 |
+
pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
104 |
+
|
105 |
+
pe_x[:, :, 0::2] = torch.sin(x * div_term)
|
106 |
+
pe_x[:, :, 1::2] = torch.cos(x * div_term)
|
107 |
+
|
108 |
+
pe_y[:, :, 0::2] = torch.sin(y * div_term)
|
109 |
+
pe_y[:, :, 1::2] = torch.cos(y * div_term)
|
110 |
+
|
111 |
+
pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3)
|
112 |
+
if cat_coords:
|
113 |
+
pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3)
|
114 |
+
return pe
|
115 |
+
|
116 |
+
|
117 |
+
def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
|
118 |
+
r"""Sample a tensor using bilinear interpolation
|
119 |
+
|
120 |
+
`bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
|
121 |
+
coordinates :attr:`coords` using bilinear interpolation. It is the same
|
122 |
+
as `torch.nn.functional.grid_sample()` but with a different coordinate
|
123 |
+
convention.
|
124 |
+
|
125 |
+
The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
|
126 |
+
:math:`B` is the batch size, :math:`C` is the number of channels,
|
127 |
+
:math:`H` is the height of the image, and :math:`W` is the width of the
|
128 |
+
image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
|
129 |
+
interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
|
130 |
+
|
131 |
+
Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
|
132 |
+
in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
|
133 |
+
that in this case the order of the components is slightly different
|
134 |
+
from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
|
135 |
+
|
136 |
+
If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
|
137 |
+
in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
|
138 |
+
left-most image pixel :math:`W-1` to the center of the right-most
|
139 |
+
pixel.
|
140 |
+
|
141 |
+
If `align_corners` is `False`, the coordinate :math:`x` is assumed to
|
142 |
+
be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
|
143 |
+
the left-most pixel :math:`W` to the right edge of the right-most
|
144 |
+
pixel.
|
145 |
+
|
146 |
+
Similar conventions apply to the :math:`y` for the range
|
147 |
+
:math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
|
148 |
+
:math:`[0,T-1]` and :math:`[0,T]`.
|
149 |
+
|
150 |
+
Args:
|
151 |
+
input (Tensor): batch of input images.
|
152 |
+
coords (Tensor): batch of coordinates.
|
153 |
+
align_corners (bool, optional): Coordinate convention. Defaults to `True`.
|
154 |
+
padding_mode (str, optional): Padding mode. Defaults to `"border"`.
|
155 |
+
|
156 |
+
Returns:
|
157 |
+
Tensor: sampled points.
|
158 |
+
"""
|
159 |
+
coords = coords.detach().clone()
|
160 |
+
############################################################
|
161 |
+
# IMPORTANT:
|
162 |
+
coords = coords.to(input.device).to(input.dtype)
|
163 |
+
############################################################
|
164 |
+
|
165 |
+
sizes = input.shape[2:]
|
166 |
+
|
167 |
+
assert len(sizes) in [2, 3]
|
168 |
+
|
169 |
+
if len(sizes) == 3:
|
170 |
+
# t x y -> x y t to match dimensions T H W in grid_sample
|
171 |
+
coords = coords[..., [1, 2, 0]]
|
172 |
+
|
173 |
+
if align_corners:
|
174 |
+
scale = torch.tensor(
|
175 |
+
[2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device, dtype=coords.dtype
|
176 |
+
)
|
177 |
+
else:
|
178 |
+
scale = torch.tensor([2 / size for size in reversed(sizes)], device=coords.device, dtype=coords.dtype)
|
179 |
+
|
180 |
+
coords.mul_(scale) # coords = coords * scale
|
181 |
+
coords.sub_(1) # coords = coords - 1
|
182 |
+
|
183 |
+
return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
|
184 |
+
|
185 |
+
|
186 |
+
def sample_features4d(input, coords):
|
187 |
+
r"""Sample spatial features
|
188 |
+
|
189 |
+
`sample_features4d(input, coords)` samples the spatial features
|
190 |
+
:attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
|
191 |
+
|
192 |
+
The field is sampled at coordinates :attr:`coords` using bilinear
|
193 |
+
interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
|
194 |
+
2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
|
195 |
+
same convention as :func:`bilinear_sampler` with `align_corners=True`.
|
196 |
+
|
197 |
+
The output tensor has one feature per point, and has shape :math:`(B,
|
198 |
+
R, C)`.
|
199 |
+
|
200 |
+
Args:
|
201 |
+
input (Tensor): spatial features.
|
202 |
+
coords (Tensor): points.
|
203 |
+
|
204 |
+
Returns:
|
205 |
+
Tensor: sampled features.
|
206 |
+
"""
|
207 |
+
|
208 |
+
B, _, _, _ = input.shape
|
209 |
+
|
210 |
+
# B R 2 -> B R 1 2
|
211 |
+
coords = coords.unsqueeze(2)
|
212 |
+
|
213 |
+
# B C R 1
|
214 |
+
feats = bilinear_sampler(input, coords)
|
215 |
+
|
216 |
+
return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C
|
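Illustrative note (not part of the diff): a hedged sketch of the positional-embedding and sampling helpers in utils.py above, assuming the streamvggt package is importable; shapes follow the docstrings.

import torch
from streamvggt.heads.track_modules.utils import get_2d_sincos_pos_embed, sample_features4d

pos = get_2d_sincos_pos_embed(embed_dim=64, grid_size=(24, 32))
print(pos.shape)  # torch.Size([1, 64, 24, 32]) -- (1, D, H, W)

feat = torch.randn(2, 128, 48, 64)       # (B, C, H, W) feature map
pts = torch.rand(2, 10, 2) * 32          # (B, R, 2) query points in (x, y) pixel coords
sampled = sample_features4d(feat, pts)   # bilinear sampling with align_corners=True
print(sampled.shape)  # torch.Size([2, 10, 128])
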
vggt/heads/__pycache__/utils.cpython-310.pyc → streamvggt/heads/utils.py
RENAMED
Binary files a/vggt/heads/__pycache__/utils.cpython-310.pyc and b/streamvggt/heads/utils.py differ
|
|
streamvggt/layers/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .mlp import Mlp
from .patch_embed import PatchEmbed
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
from .block import NestedTensorBlock
from .attention import MemEffAttention

streamvggt/layers/attention.py
ADDED
@@ -0,0 +1,129 @@
import logging
import os
import warnings

import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F
from typing import Union, Tuple, Dict, Optional

from einops import rearrange

XFORMERS_AVAILABLE = False


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.fused_attn = fused_attn

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)
        self.rope = rope

    def forward(self,
                x: torch.Tensor,
                pos=None,
                attn_mask=None,
                past_key_values=None,
                use_cache=False
                ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple]]:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        pos_k = pos
        if use_cache:
            k = k.unsqueeze(2)
            v = v.unsqueeze(2)
            if past_key_values is not None:
                past_k, past_v = past_key_values
                k = torch.cat([past_k, k], dim=2)
                v = torch.cat([past_v, v], dim=2)

            new_kv = (k, v)
            a, b, c, d, e = k.shape
            k = k.reshape(a, b, c * d, e)
            v = v.reshape(a, b, c * d, e)
            if pos_k is not None:
                # print(pos_k.shape)
                pos_k = pos_k.repeat(1, c, 1)
                # print(pos_k.shape)

        q, k = self.q_norm(q), self.k_norm(k)

        if self.rope is not None:
            q = self.rope(q, pos)
            k = self.rope(k, pos_k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attn_mask,
                dropout_p=self.attn_drop.p if self.training else 0.0,
            )

        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)

            # Mask
            if attn_mask is not None:
                assert attn_mask.shape[-2:] == (N, N), f"Expected mask shape [..., {N}, {N}], got {attn_mask.shape}"
                attn = attn + attn_mask

            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        if use_cache:
            return x, new_kv
        return x


class MemEffAttention(Attention):
    def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor:
        assert pos is None
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(x)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        q, k, v = unbind(qkv, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)

        return x

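Illustrative note (not part of the diff): the use_cache branch of Attention above implements an incremental key/value cache, so tokens of a newly arriving frame can attend to all previously seen frames without recomputing them. A hedged sketch, assuming the streamvggt package and its dependencies (e.g. einops) are installed; sizes are illustrative.

import torch
from streamvggt.layers.attention import Attention

attn = Attention(dim=64, num_heads=4)
attn.eval()

past_kv = None
with torch.no_grad():
    for frame_idx in range(3):
        tokens = torch.randn(1, 16, 64)  # tokens of the current frame
        out, past_kv = attn(tokens, past_key_values=past_kv, use_cache=True)
        # past_kv stacks per-frame keys/values along dim=2: (B, heads, frames, N, head_dim)
        print(frame_idx, out.shape, past_kv[0].shape[2])
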
streamvggt/layers/block.py
ADDED
@@ -0,0 +1,263 @@
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
from typing import Callable, List, Any, Tuple, Dict, Union
|
4 |
+
import warnings
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch import nn, Tensor
|
8 |
+
|
9 |
+
from .attention import Attention
|
10 |
+
from .drop_path import DropPath
|
11 |
+
from .layer_scale import LayerScale
|
12 |
+
from .mlp import Mlp
|
13 |
+
|
14 |
+
|
15 |
+
XFORMERS_AVAILABLE = False
|
16 |
+
|
17 |
+
|
18 |
+
class Block(nn.Module):
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
dim: int,
|
22 |
+
num_heads: int,
|
23 |
+
mlp_ratio: float = 4.0,
|
24 |
+
qkv_bias: bool = True,
|
25 |
+
proj_bias: bool = True,
|
26 |
+
ffn_bias: bool = True,
|
27 |
+
drop: float = 0.0,
|
28 |
+
attn_drop: float = 0.0,
|
29 |
+
init_values=None,
|
30 |
+
drop_path: float = 0.0,
|
31 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
32 |
+
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
|
33 |
+
attn_class: Callable[..., nn.Module] = Attention,
|
34 |
+
ffn_layer: Callable[..., nn.Module] = Mlp,
|
35 |
+
qk_norm: bool = False,
|
36 |
+
fused_attn: bool = True, # use F.scaled_dot_product_attention or not
|
37 |
+
rope=None,
|
38 |
+
) -> None:
|
39 |
+
super().__init__()
|
40 |
+
|
41 |
+
self.norm1 = norm_layer(dim)
|
42 |
+
|
43 |
+
self.attn = attn_class(
|
44 |
+
dim,
|
45 |
+
num_heads=num_heads,
|
46 |
+
qkv_bias=qkv_bias,
|
47 |
+
proj_bias=proj_bias,
|
48 |
+
attn_drop=attn_drop,
|
49 |
+
proj_drop=drop,
|
50 |
+
qk_norm=qk_norm,
|
51 |
+
fused_attn=fused_attn,
|
52 |
+
rope=rope,
|
53 |
+
)
|
54 |
+
|
55 |
+
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
56 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
57 |
+
|
58 |
+
self.norm2 = norm_layer(dim)
|
59 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
60 |
+
self.mlp = ffn_layer(
|
61 |
+
in_features=dim,
|
62 |
+
hidden_features=mlp_hidden_dim,
|
63 |
+
act_layer=act_layer,
|
64 |
+
drop=drop,
|
65 |
+
bias=ffn_bias,
|
66 |
+
)
|
67 |
+
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
68 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
69 |
+
|
70 |
+
self.sample_drop_ratio = drop_path
|
71 |
+
|
72 |
+
def forward(self, x: Tensor, pos=None, attn_mask=None, past_key_values=None, use_cache=False) -> Union[Tensor, Tuple[Tensor, Dict]]:
|
73 |
+
|
74 |
+
def attn_residual_func(x: Tensor, pos=None, attn_mask=None, past_key_values=None, use_cache=False) -> Union[Tensor, Tuple[Tensor, Dict]]:
|
75 |
+
if use_cache:
|
76 |
+
output, new_kv = self.attn(self.norm1(x), pos=pos, past_key_values=past_key_values, use_cache=True)
|
77 |
+
return self.ls1(output), new_kv
|
78 |
+
else:
|
79 |
+
if attn_mask is not None:
|
80 |
+
return self.ls1(self.attn(self.norm1(x), pos=pos, attn_mask=attn_mask))
|
81 |
+
else:
|
82 |
+
return self.ls1(self.attn(self.norm1(x), pos=pos))
|
83 |
+
def ffn_residual_func(x: Tensor) -> Tensor:
|
84 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
85 |
+
|
86 |
+
if use_cache:
|
87 |
+
attn_output, new_kv = attn_residual_func(x, pos=pos, past_key_values=past_key_values, use_cache=True)
|
88 |
+
x = x + attn_output
|
89 |
+
x = x + ffn_residual_func(x)
|
90 |
+
return x, new_kv
|
91 |
+
|
92 |
+
if self.training and self.sample_drop_ratio > 0.1:
|
93 |
+
# the overhead is compensated only for a drop path rate larger than 0.1
|
94 |
+
x = drop_add_residual_stochastic_depth(
|
95 |
+
x,
|
96 |
+
pos=pos,
|
97 |
+
residual_func=attn_residual_func,
|
98 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
99 |
+
)
|
100 |
+
x = drop_add_residual_stochastic_depth(
|
101 |
+
x,
|
102 |
+
residual_func=ffn_residual_func,
|
103 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
104 |
+
)
|
105 |
+
elif self.training and self.sample_drop_ratio > 0.0:
|
106 |
+
x = x + self.drop_path1(attn_residual_func(x, pos=pos, attn_mask=attn_mask))
|
107 |
+
x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
|
108 |
+
else:
|
109 |
+
x = x + attn_residual_func(x, pos=pos, attn_mask=attn_mask)
|
110 |
+
x = x + ffn_residual_func(x)
|
111 |
+
return x
|
112 |
+
|
113 |
+
|
114 |
+
def drop_add_residual_stochastic_depth(
|
115 |
+
x: Tensor,
|
116 |
+
residual_func: Callable[[Tensor], Tensor],
|
117 |
+
sample_drop_ratio: float = 0.0,
|
118 |
+
pos=None,
|
119 |
+
) -> Tensor:
|
120 |
+
# 1) extract subset using permutation
|
121 |
+
b, n, d = x.shape
|
122 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
123 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
124 |
+
x_subset = x[brange]
|
125 |
+
|
126 |
+
# 2) apply residual_func to get residual
|
127 |
+
if pos is not None:
|
128 |
+
# if necessary, apply rope to the subset
|
129 |
+
pos = pos[brange]
|
130 |
+
residual = residual_func(x_subset, pos=pos)
|
131 |
+
else:
|
132 |
+
residual = residual_func(x_subset)
|
133 |
+
|
134 |
+
x_flat = x.flatten(1)
|
135 |
+
residual = residual.flatten(1)
|
136 |
+
|
137 |
+
residual_scale_factor = b / sample_subset_size
|
138 |
+
|
139 |
+
# 3) add the residual
|
140 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
141 |
+
return x_plus_residual.view_as(x)
|
142 |
+
|
143 |
+
|
144 |
+
def get_branges_scales(x, sample_drop_ratio=0.0):
|
145 |
+
b, n, d = x.shape
|
146 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
147 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
148 |
+
residual_scale_factor = b / sample_subset_size
|
149 |
+
return brange, residual_scale_factor
|
150 |
+
|
151 |
+
|
152 |
+
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
|
153 |
+
if scaling_vector is None:
|
154 |
+
x_flat = x.flatten(1)
|
155 |
+
residual = residual.flatten(1)
|
156 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
157 |
+
else:
|
158 |
+
x_plus_residual = scaled_index_add(
|
159 |
+
x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
|
160 |
+
)
|
161 |
+
return x_plus_residual
|
162 |
+
|
163 |
+
|
164 |
+
attn_bias_cache: Dict[Tuple, Any] = {}
|
165 |
+
|
166 |
+
|
167 |
+
def get_attn_bias_and_cat(x_list, branges=None):
|
168 |
+
"""
|
169 |
+
this will perform the index select, cat the tensors, and provide the attn_bias from cache
|
170 |
+
"""
|
171 |
+
batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
|
172 |
+
all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
|
173 |
+
if all_shapes not in attn_bias_cache.keys():
|
174 |
+
seqlens = []
|
175 |
+
for b, x in zip(batch_sizes, x_list):
|
176 |
+
for _ in range(b):
|
177 |
+
seqlens.append(x.shape[1])
|
178 |
+
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
|
179 |
+
attn_bias._batch_sizes = batch_sizes
|
180 |
+
attn_bias_cache[all_shapes] = attn_bias
|
181 |
+
|
182 |
+
if branges is not None:
|
183 |
+
cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
|
184 |
+
else:
|
185 |
+
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
|
186 |
+
cat_tensors = torch.cat(tensors_bs1, dim=1)
|
187 |
+
|
188 |
+
return attn_bias_cache[all_shapes], cat_tensors
|
189 |
+
|
190 |
+
|
191 |
+
def drop_add_residual_stochastic_depth_list(
|
192 |
+
x_list: List[Tensor],
|
193 |
+
residual_func: Callable[[Tensor, Any], Tensor],
|
194 |
+
sample_drop_ratio: float = 0.0,
|
195 |
+
scaling_vector=None,
|
196 |
+
) -> Tensor:
|
197 |
+
# 1) generate random set of indices for dropping samples in the batch
|
198 |
+
branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
|
199 |
+
branges = [s[0] for s in branges_scales]
|
200 |
+
residual_scale_factors = [s[1] for s in branges_scales]
|
201 |
+
|
202 |
+
# 2) get attention bias and index+concat the tensors
|
203 |
+
attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
|
204 |
+
|
205 |
+
# 3) apply residual_func to get residual, and split the result
|
206 |
+
residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
|
207 |
+
|
208 |
+
outputs = []
|
209 |
+
for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
|
210 |
+
outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
|
211 |
+
return outputs
|
212 |
+
|
213 |
+
|
214 |
+
class NestedTensorBlock(Block):
|
215 |
+
def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
|
216 |
+
"""
|
217 |
+
x_list contains a list of tensors to nest together and run
|
218 |
+
"""
|
219 |
+
assert isinstance(self.attn, MemEffAttention)
|
220 |
+
|
221 |
+
if self.training and self.sample_drop_ratio > 0.0:
|
222 |
+
|
223 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
224 |
+
return self.attn(self.norm1(x), attn_bias=attn_bias)
|
225 |
+
|
226 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
227 |
+
return self.mlp(self.norm2(x))
|
228 |
+
|
229 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
230 |
+
x_list,
|
231 |
+
residual_func=attn_residual_func,
|
232 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
233 |
+
scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
|
234 |
+
)
|
235 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
236 |
+
x_list,
|
237 |
+
residual_func=ffn_residual_func,
|
238 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
239 |
+
scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
|
240 |
+
)
|
241 |
+
return x_list
|
242 |
+
else:
|
243 |
+
|
244 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
245 |
+
return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
|
246 |
+
|
247 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
248 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
249 |
+
|
250 |
+
attn_bias, x = get_attn_bias_and_cat(x_list)
|
251 |
+
x = x + attn_residual_func(x, attn_bias=attn_bias)
|
252 |
+
x = x + ffn_residual_func(x)
|
253 |
+
return attn_bias.split(x)
|
254 |
+
|
255 |
+
def forward(self, x_or_x_list):
|
256 |
+
if isinstance(x_or_x_list, Tensor):
|
257 |
+
return super().forward(x_or_x_list)
|
258 |
+
elif isinstance(x_or_x_list, list):
|
259 |
+
if not XFORMERS_AVAILABLE:
|
260 |
+
raise AssertionError("xFormers is required for using nested tensors")
|
261 |
+
return self.forward_nested(x_or_x_list)
|
262 |
+
else:
|
263 |
+
raise AssertionError
|
streamvggt/layers/drop_path.py
ADDED
@@ -0,0 +1,24 @@
from torch import nn


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        random_tensor.div_(keep_prob)
    output = x * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

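Illustrative note (not part of the diff): drop_path above zeroes whole samples during training and rescales the survivors by 1/keep_prob, so the expected activation is preserved; in eval mode it is the identity. A quick check:

import torch
from streamvggt.layers.drop_path import drop_path

x = torch.ones(1000, 8)
y = drop_path(x, drop_prob=0.2, training=True)
print((y.sum(dim=1) == 0).float().mean())  # roughly 0.2 of the samples are dropped
print(y.mean())                            # roughly 1.0 (survivors rescaled by 1/0.8)
print(torch.equal(drop_path(x, 0.2, training=False), x))  # True: identity at eval time
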
streamvggt/layers/layer_scale.py
ADDED
@@ -0,0 +1,20 @@
from typing import Union

import torch
from torch import Tensor
from torch import nn


class LayerScale(nn.Module):
    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma

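Illustrative note (not part of the diff): LayerScale multiplies each channel by a learnable gamma initialized to init_values, which keeps residual branches near-zero early in training. A tiny sketch, assuming the package is importable:

import torch
from streamvggt.layers.layer_scale import LayerScale

ls = LayerScale(dim=8, init_values=1e-5)
x = torch.randn(2, 4, 8)
print(ls(x).abs().max())  # tiny values: every channel is scaled by 1e-5 at init
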
streamvggt/layers/mlp.py
ADDED
@@ -0,0 +1,30 @@
from typing import Callable, Optional

from torch import Tensor, nn


class Mlp(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

streamvggt/layers/patch_embed.py
ADDED
@@ -0,0 +1,79 @@
from typing import Callable, Optional, Tuple, Union

from torch import Tensor
import torch.nn as nn


def make_2tuple(x):
    if isinstance(x, tuple):
        assert len(x) == 2
        return x

    assert isinstance(x, int)
    return (x, x)


class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H W
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x

    def flops(self) -> float:
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops

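Illustrative note (not part of the diff): a minimal usage sketch for PatchEmbed above with illustrative sizes; a 224x224 image with 16x16 patches yields 14*14 = 196 tokens of width embed_dim.

import torch
from streamvggt.layers.patch_embed import PatchEmbed

embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
img = torch.randn(2, 3, 224, 224)
tokens = embed(img)
print(tokens.shape)        # torch.Size([2, 196, 768])
print(embed.num_patches)   # 196
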
vggt/layers/__pycache__/rope.cpython-310.pyc → streamvggt/layers/rope.py
RENAMED
Binary files a/vggt/layers/__pycache__/rope.cpython-310.pyc and b/streamvggt/layers/rope.py differ
|
|
streamvggt/layers/swiglu_ffn.py
ADDED
@@ -0,0 +1,67 @@
import os
from typing import Callable, Optional
import warnings

from torch import Tensor, nn
import torch.nn.functional as F


class SwiGLUFFN(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        x12 = self.w12(x)
        x1, x2 = x12.chunk(2, dim=-1)
        hidden = F.silu(x1) * x2
        return self.w3(hidden)


XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
# try:
#     if XFORMERS_ENABLED:
#         from xformers.ops import SwiGLU

#         XFORMERS_AVAILABLE = True
#         warnings.warn("xFormers is available (SwiGLU)")
#     else:
#         warnings.warn("xFormers is disabled (SwiGLU)")
#         raise ImportError
# except ImportError:
SwiGLU = SwiGLUFFN
XFORMERS_AVAILABLE = False

# warnings.warn("xFormers is not available (SwiGLU)")


class SwiGLUFFNFused(SwiGLU):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )

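Illustrative note (not part of the diff): SwiGLUFFNFused above shrinks the requested hidden width by 2/3 (keeping the parameter count roughly comparable to a plain two-matrix MLP, since SwiGLU uses three weight matrices) and rounds up to a multiple of 8. The rule in isolation:

def fused_hidden(hidden_features: int) -> int:
    # Same arithmetic as in SwiGLUFFNFused.__init__ above.
    return (int(hidden_features * 2 / 3) + 7) // 8 * 8

for h in (1024, 3072, 4096):
    print(h, "->", fused_hidden(h))   # 1024 -> 688, 3072 -> 2048, 4096 -> 2736
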
streamvggt/layers/vision_transformer.py
ADDED
@@ -0,0 +1,398 @@
1 |
+
from functools import partial
|
2 |
+
import math
|
3 |
+
import logging
|
4 |
+
from typing import Sequence, Tuple, Union, Callable
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from torch.utils.checkpoint import checkpoint
|
9 |
+
from torch.nn.init import trunc_normal_
|
10 |
+
from . import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
|
11 |
+
|
12 |
+
logger = logging.getLogger("dinov2")
|
13 |
+
|
14 |
+
|
15 |
+
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
|
16 |
+
if not depth_first and include_root:
|
17 |
+
fn(module=module, name=name)
|
18 |
+
for child_name, child_module in module.named_children():
|
19 |
+
child_name = ".".join((name, child_name)) if name else child_name
|
20 |
+
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
21 |
+
if depth_first and include_root:
|
22 |
+
fn(module=module, name=name)
|
23 |
+
return module
|
24 |
+
|
25 |
+
|
26 |
+
class BlockChunk(nn.ModuleList):
|
27 |
+
def forward(self, x):
|
28 |
+
for b in self:
|
29 |
+
x = b(x)
|
30 |
+
return x
|
31 |
+
|
32 |
+
|
33 |
+
class DinoVisionTransformer(nn.Module):
|
34 |
+
def __init__(
|
35 |
+
self,
|
36 |
+
img_size=224,
|
37 |
+
patch_size=16,
|
38 |
+
in_chans=3,
|
39 |
+
embed_dim=768,
|
40 |
+
depth=12,
|
41 |
+
num_heads=12,
|
42 |
+
mlp_ratio=4.0,
|
43 |
+
qkv_bias=True,
|
44 |
+
ffn_bias=True,
|
45 |
+
proj_bias=True,
|
46 |
+
drop_path_rate=0.0,
|
47 |
+
drop_path_uniform=False,
|
48 |
+
init_values=None, # for layerscale: None or 0 => no layerscale
|
49 |
+
embed_layer=PatchEmbed,
|
50 |
+
act_layer=nn.GELU,
|
51 |
+
block_fn=Block,
|
52 |
+
ffn_layer="mlp",
|
53 |
+
block_chunks=1,
|
54 |
+
num_register_tokens=0,
|
55 |
+
interpolate_antialias=False,
|
56 |
+
interpolate_offset=0.1,
|
57 |
+
qk_norm=False,
|
58 |
+
):
|
59 |
+
"""
|
60 |
+
Args:
|
61 |
+
img_size (int, tuple): input image size
|
62 |
+
patch_size (int, tuple): patch size
|
63 |
+
in_chans (int): number of input channels
|
64 |
+
embed_dim (int): embedding dimension
|
65 |
+
depth (int): depth of transformer
|
66 |
+
num_heads (int): number of attention heads
|
67 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
68 |
+
qkv_bias (bool): enable bias for qkv if True
|
69 |
+
proj_bias (bool): enable bias for proj in attn if True
|
70 |
+
ffn_bias (bool): enable bias for ffn if True
|
71 |
+
drop_path_rate (float): stochastic depth rate
|
72 |
+
drop_path_uniform (bool): apply uniform drop rate across blocks
|
73 |
+
weight_init (str): weight init scheme
|
74 |
+
init_values (float): layer-scale init values
|
75 |
+
embed_layer (nn.Module): patch embedding layer
|
76 |
+
act_layer (nn.Module): MLP activation layer
|
77 |
+
block_fn (nn.Module): transformer block class
|
78 |
+
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
|
79 |
+
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
|
80 |
+
num_register_tokens: (int) number of extra cls tokens (so-called "registers")
|
81 |
+
interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
|
82 |
+
interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
|
83 |
+
"""
|
84 |
+
super().__init__()
|
85 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
86 |
+
|
87 |
+
# tricky but makes it work
|
88 |
+
self.use_checkpoint = False
|
89 |
+
#
|
90 |
+
|
91 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
92 |
+
self.num_tokens = 1
|
93 |
+
self.n_blocks = depth
|
94 |
+
self.num_heads = num_heads
|
95 |
+
self.patch_size = patch_size
|
96 |
+
self.num_register_tokens = num_register_tokens
|
97 |
+
self.interpolate_antialias = interpolate_antialias
|
98 |
+
self.interpolate_offset = interpolate_offset
|
99 |
+
|
100 |
+
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
101 |
+
num_patches = self.patch_embed.num_patches
|
102 |
+
|
103 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
104 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
105 |
+
assert num_register_tokens >= 0
|
106 |
+
self.register_tokens = (
|
107 |
+
nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
|
108 |
+
)
|
109 |
+
|
110 |
+
if drop_path_uniform is True:
|
111 |
+
dpr = [drop_path_rate] * depth
|
112 |
+
else:
|
113 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
114 |
+
|
115 |
+
if ffn_layer == "mlp":
|
116 |
+
logger.info("using MLP layer as FFN")
|
117 |
+
ffn_layer = Mlp
|
118 |
+
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
|
119 |
+
logger.info("using SwiGLU layer as FFN")
|
120 |
+
ffn_layer = SwiGLUFFNFused
|
121 |
+
elif ffn_layer == "identity":
|
122 |
+
logger.info("using Identity layer as FFN")
|
123 |
+
|
124 |
+
def f(*args, **kwargs):
|
125 |
+
return nn.Identity()
|
126 |
+
|
127 |
+
ffn_layer = f
|
128 |
+
else:
|
129 |
+
raise NotImplementedError
|
130 |
+
|
131 |
+
blocks_list = [
|
132 |
+
block_fn(
|
133 |
+
dim=embed_dim,
|
134 |
+
num_heads=num_heads,
|
135 |
+
mlp_ratio=mlp_ratio,
|
136 |
+
qkv_bias=qkv_bias,
|
137 |
+
proj_bias=proj_bias,
|
138 |
+
ffn_bias=ffn_bias,
|
139 |
+
drop_path=dpr[i],
|
140 |
+
norm_layer=norm_layer,
|
141 |
+
act_layer=act_layer,
|
142 |
+
ffn_layer=ffn_layer,
|
143 |
+
init_values=init_values,
|
144 |
+
qk_norm=qk_norm,
|
145 |
+
)
|
146 |
+
for i in range(depth)
|
147 |
+
]
|
148 |
+
if block_chunks > 0:
|
149 |
+
self.chunked_blocks = True
|
150 |
+
chunked_blocks = []
|
151 |
+
chunksize = depth // block_chunks
|
152 |
+
for i in range(0, depth, chunksize):
|
153 |
+
# this is to keep the block index consistent if we chunk the block list
|
154 |
+
chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
|
155 |
+
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
|
156 |
+
else:
|
157 |
+
self.chunked_blocks = False
|
158 |
+
self.blocks = nn.ModuleList(blocks_list)
|
159 |
+
|
160 |
+
self.norm = norm_layer(embed_dim)
|
161 |
+
self.head = nn.Identity()
|
162 |
+
|
163 |
+
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
|
164 |
+
|
165 |
+
self.init_weights()
|
166 |
+
|
167 |
+
def init_weights(self):
|
168 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
169 |
+
nn.init.normal_(self.cls_token, std=1e-6)
|
170 |
+
if self.register_tokens is not None:
|
171 |
+
nn.init.normal_(self.register_tokens, std=1e-6)
|
172 |
+
named_apply(init_weights_vit_timm, self)
|
173 |
+
|
174 |
+
def interpolate_pos_encoding(self, x, w, h):
|
175 |
+
previous_dtype = x.dtype
|
176 |
+
npatch = x.shape[1] - 1
|
177 |
+
N = self.pos_embed.shape[1] - 1
|
178 |
+
if npatch == N and w == h:
|
179 |
+
return self.pos_embed
|
180 |
+
pos_embed = self.pos_embed.float()
|
181 |
+
class_pos_embed = pos_embed[:, 0]
|
182 |
+
patch_pos_embed = pos_embed[:, 1:]
|
183 |
+
dim = x.shape[-1]
|
184 |
+
w0 = w // self.patch_size
|
185 |
+
h0 = h // self.patch_size
|
186 |
+
M = int(math.sqrt(N)) # Recover the number of patches in each dimension
|
187 |
+
assert N == M * M
|
188 |
+
kwargs = {}
|
189 |
+
if self.interpolate_offset:
|
190 |
+
# Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
|
191 |
+
# Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
|
192 |
+
sx = float(w0 + self.interpolate_offset) / M
|
193 |
+
sy = float(h0 + self.interpolate_offset) / M
|
194 |
+
kwargs["scale_factor"] = (sx, sy)
|
195 |
+
else:
|
196 |
+
# Simply specify an output size instead of a scale factor
|
197 |
+
kwargs["size"] = (w0, h0)
|
198 |
+
patch_pos_embed = nn.functional.interpolate(
|
199 |
+
patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
|
200 |
+
mode="bicubic",
|
201 |
+
antialias=self.interpolate_antialias,
|
202 |
+
**kwargs,
|
203 |
+
)
|
204 |
+
assert (w0, h0) == patch_pos_embed.shape[-2:]
|
205 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
206 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
|
207 |
+
|
208 |
+
def prepare_tokens_with_masks(self, x, masks=None):
|
209 |
+
B, nc, w, h = x.shape
|
210 |
+
x = self.patch_embed(x)
|
211 |
+
if masks is not None:
|
212 |
+
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
|
213 |
+
|
214 |
+
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
215 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
216 |
+
|
217 |
+
if self.register_tokens is not None:
|
218 |
+
x = torch.cat(
|
219 |
+
(
|
220 |
+
x[:, :1],
|
221 |
+
self.register_tokens.expand(x.shape[0], -1, -1),
|
222 |
+
x[:, 1:],
|
223 |
+
),
|
224 |
+
dim=1,
|
225 |
+
)
|
226 |
+
|
227 |
+
return x
|
228 |
+
|
229 |
+
def forward_features_list(self, x_list, masks_list):
|
230 |
+
x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
|
231 |
+
|
232 |
+
for blk in self.blocks:
|
233 |
+
if self.use_checkpoint:
|
234 |
+
x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
|
235 |
+
else:
|
236 |
+
x = blk(x)
|
237 |
+
|
238 |
+
all_x = x
|
239 |
+
output = []
|
240 |
+
for x, masks in zip(all_x, masks_list):
|
241 |
+
x_norm = self.norm(x)
|
242 |
+
output.append(
|
243 |
+
{
|
244 |
+
"x_norm_clstoken": x_norm[:, 0],
|
245 |
+
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
246 |
+
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
247 |
+
"x_prenorm": x,
|
248 |
+
"masks": masks,
|
249 |
+
}
|
250 |
+
)
|
251 |
+
return output
|
252 |
+
|
253 |
+
def forward_features(self, x, masks=None):
|
254 |
+
if isinstance(x, list):
|
255 |
+
return self.forward_features_list(x, masks)
|
256 |
+
|
257 |
+
x = self.prepare_tokens_with_masks(x, masks)
|
258 |
+
|
259 |
+
for blk in self.blocks:
|
260 |
+
if self.use_checkpoint:
|
261 |
+
x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
|
262 |
+
else:
|
263 |
+
x = blk(x)
|
264 |
+
|
265 |
+
x_norm = self.norm(x)
|
266 |
+
return {
|
267 |
+
"x_norm_clstoken": x_norm[:, 0],
|
268 |
+
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
269 |
+
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
270 |
+
"x_prenorm": x,
|
271 |
+
"masks": masks,
|
272 |
+
}
|
273 |
+
|
274 |
+
def _get_intermediate_layers_not_chunked(self, x, n=1):
|
275 |
+
x = self.prepare_tokens_with_masks(x)
|
276 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
277 |
+
output, total_block_len = [], len(self.blocks)
|
278 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
279 |
+
for i, blk in enumerate(self.blocks):
|
280 |
+
x = blk(x)
|
281 |
+
if i in blocks_to_take:
|
282 |
+
output.append(x)
|
283 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=True, **kwargs):
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return self.head(ret["x_norm_clstoken"])


def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)"""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
    """
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model
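Not part of the diff: a minimal usage sketch of the intermediate-layer API added above, assuming the streamvggt package is importable and a randomly initialized vit_small backbone (exact constructor defaults may differ from this guess).

# Sketch: query reshaped patch-feature maps from the last two blocks of a DINOv2-style ViT.
import torch
from streamvggt.layers.vision_transformer import vit_small

model = vit_small(img_size=518, patch_size=14, num_register_tokens=4, block_chunks=0)
model.eval()

x = torch.randn(1, 3, 518, 518)  # dummy batch [B, 3, H, W]
with torch.no_grad():
    feats = model.get_intermediate_layers(x, n=2, reshape=True, return_class_token=True)
for fmap, cls_tok in feats:
    # e.g. torch.Size([1, 384, 37, 37]) and torch.Size([1, 384]) for vit_small at 518x518
    print(fmap.shape, cls_tok.shape)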
streamvggt/models/aggregator.py
ADDED
@@ -0,0 +1,394 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Union, List, Dict, Any

from streamvggt.layers import PatchEmbed
from streamvggt.layers.block import Block
from streamvggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
from streamvggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2

logger = logging.getLogger(__name__)

_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]


class Aggregator(nn.Module):
    """
    The Aggregator applies alternating-attention over input frames,
    as described in VGGT: Visual Geometry Grounded Transformer.

    Args:
        img_size (int): Image size in pixels.
        patch_size (int): Size of each patch for PatchEmbed.
        embed_dim (int): Dimension of the token embeddings.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
        num_register_tokens (int): Number of register tokens.
        block_fn (nn.Module): The block type used for attention (Block by default).
        qkv_bias (bool): Whether to include bias in QKV projections.
        proj_bias (bool): Whether to include bias in the output projection.
        ffn_bias (bool): Whether to include bias in MLP layers.
        patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg".
        aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"].
        aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1.
        qk_norm (bool): Whether to apply QK normalization.
        rope_freq (int): Base frequency for rotary embedding. -1 to disable.
        init_values (float): Init scale for layer scale.
    """

    def __init__(
        self,
        img_size=518,
        patch_size=14,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4.0,
        num_register_tokens=4,
        block_fn=Block,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
        patch_embed="dinov2_vitl14_reg",
        aa_order=["frame", "global"],
        aa_block_size=1,
        qk_norm=True,
        rope_freq=100,
        init_values=0.01,
    ):
        super().__init__()

        self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim)

        # Initialize rotary position embedding if frequency > 0
        self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
        self.position_getter = PositionGetter() if self.rope is not None else None

        self.frame_blocks = nn.ModuleList(
            [
                block_fn(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    proj_bias=proj_bias,
                    ffn_bias=ffn_bias,
                    init_values=init_values,
                    qk_norm=qk_norm,
                    rope=self.rope,
                )
                for _ in range(depth)
            ]
        )

        self.global_blocks = nn.ModuleList(
            [
                block_fn(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    proj_bias=proj_bias,
                    ffn_bias=ffn_bias,
                    init_values=init_values,
                    qk_norm=qk_norm,
                    rope=self.rope,
                )
                for _ in range(depth)
            ]
        )

        self.depth = depth
        self.aa_order = aa_order
        self.patch_size = patch_size
        self.aa_block_size = aa_block_size

        # Validate that depth is divisible by aa_block_size
        if self.depth % self.aa_block_size != 0:
            raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")

        self.aa_block_num = self.depth // self.aa_block_size

        # Note: We have two camera tokens, one for the first frame and one for the rest
        # The same applies for register tokens
        self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
        self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim))

        # The patch tokens start after the camera and register tokens
        self.patch_start_idx = 1 + num_register_tokens

        # Initialize parameters with small values
        nn.init.normal_(self.camera_token, std=1e-6)
        nn.init.normal_(self.register_token, std=1e-6)

        # Register normalization constants as buffers
        for name, value in (
            ("_resnet_mean", _RESNET_MEAN),
            ("_resnet_std", _RESNET_STD),
        ):
            self.register_buffer(
                name,
                torch.FloatTensor(value).reshape(1, 1, 3, 1, 1),
                persistent=False,
            )

    def __build_patch_embed__(
        self,
        patch_embed,
        img_size,
        patch_size,
        num_register_tokens,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        block_chunks=0,
        init_values=1.0,
        embed_dim=1024,
    ):
        """
        Build the patch embed layer. If 'conv', we use a
        simple PatchEmbed conv layer. Otherwise, we use a vision transformer.
        """

        if "conv" in patch_embed:
            self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim)
        else:
            vit_models = {
                "dinov2_vitl14_reg": vit_large,
                "dinov2_vitb14_reg": vit_base,
                "dinov2_vits14_reg": vit_small,
                "dinov2_vitg2_reg": vit_giant2,
            }

            self.patch_embed = vit_models[patch_embed](
                img_size=img_size,
                patch_size=patch_size,
                num_register_tokens=num_register_tokens,
                interpolate_antialias=interpolate_antialias,
                interpolate_offset=interpolate_offset,
                block_chunks=block_chunks,
                init_values=init_values,
            )

            # Disable gradient updates for mask token
            if hasattr(self.patch_embed, "mask_token"):
                self.patch_embed.mask_token.requires_grad_(False)

    def forward(
        self,
        images: torch.Tensor,
        past_key_values=None,
        use_cache=False,
        past_frame_idx=0
    ) -> Tuple[List[torch.Tensor], int]:
        """
        Args:
            images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
                B: batch size, S: sequence length, 3: RGB channels, H: height, W: width

        Returns:
            (list[torch.Tensor], int):
                The list of outputs from the attention blocks,
                and the patch_start_idx indicating where patch tokens begin.
        """
        B, S, C_in, H, W = images.shape

        if use_cache and past_key_values[0] is not None:
            _, _, S_true, _, _ = past_key_values[0][0].shape
            S_true += 1
        else:
            S_true = S

        if use_cache and S > 1:
            print(f"Use KV cache expects S=1, got S={S}")

        if C_in != 3:
            raise ValueError(f"Expected 3 input channels, got {C_in}")

        # Normalize images and reshape for patch embed
        images = (images - self._resnet_mean.to(images.device)) / self._resnet_std.to(images.device)

        # Reshape to [B*S, C, H, W] for patch embedding
        images = images.reshape(B * S, C_in, H, W)
        patch_tokens = self.patch_embed(images)

        if isinstance(patch_tokens, dict):
            patch_tokens = patch_tokens["x_norm_patchtokens"]

        _, P, C = patch_tokens.shape

        if use_cache:
            camera_token_full = slice_expand_and_flatten(self.camera_token, B, S_true)
            camera_token = camera_token_full[-1:, :, :]

            register_token_full = slice_expand_and_flatten(self.register_token, B, S_true)
            register_token = register_token_full[-1:, :, :]
        else:
            camera_token = slice_expand_and_flatten(self.camera_token, B, S)
            register_token = slice_expand_and_flatten(self.register_token, B, S)
        # Concatenate special tokens with patch tokens
        tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1)

        pos = None
        if self.rope is not None:
            pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device)

        if self.patch_start_idx > 0:
            # do not use position embedding for special tokens (camera and register tokens)
            # so set pos to 0 for the special tokens
            pos = pos + 1
            pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype)
            pos = torch.cat([pos_special, pos], dim=1)

        # update P because we added special tokens
        _, P, C = tokens.shape

        frame_idx = 0
        global_idx = 0
        output_list = []

        for _ in range(self.aa_block_num):
            for attn_type in self.aa_order:
                if attn_type == "frame":
                    tokens, frame_idx, frame_intermediates = self._process_frame_attention(
                        tokens, B, S, P, C, frame_idx, pos=pos
                    )
                elif attn_type == "global":
                    if use_cache:
                        if past_key_values[global_idx] is not None:
                            k, v = past_key_values[global_idx]
                        tokens, global_idx, global_intermediates, new_kv = self._process_global_attention(
                            tokens, B, S, P, C, global_idx, pos=pos,
                            past_key_values_block=past_key_values[global_idx] if past_key_values[global_idx] is not None else None,
                            use_cache=True,
                            past_frame_idx=past_frame_idx
                        )
                        past_key_values[global_idx - 1] = new_kv
                    else:
                        tokens, global_idx, global_intermediates = self._process_global_attention(
                            tokens, B, S, P, C, global_idx, pos=pos
                        )
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")
                for i in range(len(frame_intermediates)):
                    # concat frame and global intermediates, [B x S x P x 2C]
                    concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
                    output_list.append(concat_inter)

        del concat_inter
        del frame_intermediates
        del global_intermediates
        if use_cache:
            return output_list, self.patch_start_idx, past_key_values
        return output_list, self.patch_start_idx

    def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
        """
        Process frame attention blocks. We keep tokens in shape (B*S, P, C).
        """
        # If needed, reshape tokens or positions:
        if tokens.shape != (B * S, P, C):
            tokens = tokens.reshape(B, S, P, C).reshape(B * S, P, C)

        if pos is not None and pos.shape != (B * S, P, 2):
            pos = pos.reshape(B, S, P, 2).reshape(B * S, P, 2)

        intermediates = []

        # by default, self.aa_block_size=1, which processes one block at a time
        for _ in range(self.aa_block_size):
            tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
            frame_idx += 1
            intermediates.append(tokens.reshape(B, S, P, C))

        return tokens, frame_idx, intermediates

    def _process_global_attention(
        self,
        tokens,
        B,
        S,
        P,
        C,
        global_idx,
        pos=None,
        past_key_values_block=None,
        use_cache=False,
        past_frame_idx=0
    ) -> Union[Tuple[torch.Tensor, int, List[torch.Tensor]], Tuple[torch.Tensor, int, List[torch.Tensor], List]]:
        """
        Process global attention blocks. We keep tokens in shape (B, S*P, C).
        """

        if tokens.shape != (B, S * P, C):
            tokens = tokens.reshape(B, S, P, C).reshape(B, S * P, C)

        if pos is not None and pos.shape != (B, S * P, 2):
            pos = pos.reshape(B, S, P, 2).reshape(B, S * P, 2)

        intermediates = []

        for _ in range(self.aa_block_size):
            if not use_cache:
                L = S * P
                frame_ids = torch.arange(L, device=tokens.device) // P  # [0,0,...,1,1,...,S-1]
                future_frame = frame_ids.unsqueeze(1) < frame_ids.unsqueeze(0)
                attn_mask = future_frame.to(tokens.dtype) * torch.finfo(tokens.dtype).min
            else:
                attn_mask = None

            if use_cache:
                tokens, block_kv = self.global_blocks[global_idx](
                    tokens,
                    pos=pos,
                    attn_mask=attn_mask,
                    past_key_values=past_key_values_block,
                    use_cache=True
                )
            else:
                tokens = self.global_blocks[global_idx](tokens, pos=pos, attn_mask=attn_mask)
            global_idx += 1
            intermediates.append(tokens.reshape(B, S, P, C))

        # if self.use_causal_global:
        #     del attn_mask
        if use_cache:
            return tokens, global_idx, intermediates, block_kv
        return tokens, global_idx, intermediates


def slice_expand_and_flatten(token_tensor, B, S):
    """
    Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing:
    1) Uses the first position (index=0) for the first frame only
    2) Uses the second position (index=1) for all remaining frames (S-1 frames)
    3) Expands both to match batch size B
    4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token
       followed by (S-1) second-position tokens
    5) Flattens to (B*S, X, C) for processing

    Returns:
        torch.Tensor: Processed tokens with shape (B*S, X, C)
    """

    # Slice out the "query" tokens => shape (1, 1, ...)
    query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:])
    # Slice out the "other" tokens => shape (1, S-1, ...)
    others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:])
    # Concatenate => shape (B, S, ...)
    combined = torch.cat([query, others], dim=1)

    # Finally flatten => shape (B*S, ...)
    combined = combined.reshape(B * S, *combined.shape[2:])
    return combined
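Not part of the diff: a small sketch of the frame-causal mask that _process_global_attention builds when no KV cache is used, shown on toy sizes so the pattern is visible.

# Sketch: causal-over-frames attention mask for S frames of P tokens each.
import torch

S, P = 3, 2                                   # 3 frames, 2 tokens per frame
L = S * P
frame_ids = torch.arange(L) // P              # tensor([0, 0, 1, 1, 2, 2])
future_frame = frame_ids.unsqueeze(1) < frame_ids.unsqueeze(0)
attn_mask = future_frame.float() * torch.finfo(torch.float32).min

# A token may only attend to tokens from its own or earlier frames, which is what
# makes global attention causal over the streaming frame sequence.
print(future_frame.int())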
streamvggt/models/streamvggt.py
ADDED
@@ -0,0 +1,172 @@
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin  # used for model hub

from streamvggt.models.aggregator import Aggregator
from streamvggt.heads.camera_head import CameraHead
from streamvggt.heads.dpt_head import DPTHead
from streamvggt.heads.track_head import TrackHead
from transformers.file_utils import ModelOutput
from typing import Optional, Tuple, List, Any
from dataclasses import dataclass

@dataclass
class StreamVGGTOutput(ModelOutput):
    ress: Optional[List[dict]] = None
    views: Optional[torch.Tensor] = None

class StreamVGGT(nn.Module, PyTorchModelHubMixin):
    def __init__(self, img_size=518, patch_size=14, embed_dim=1024):
        super().__init__()

        self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
        self.camera_head = CameraHead(dim_in=2 * embed_dim)
        self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1")
        self.depth_head = DPTHead(dim_in=2 * embed_dim, output_dim=2, activation="exp", conf_activation="expp1")
        self.track_head = TrackHead(dim_in=2 * embed_dim, patch_size=patch_size)

    def forward(
        self,
        views,
        query_points: torch.Tensor = None,
        history_info: Optional[dict] = None,
        past_key_values=None,
        use_cache=False,
        past_frame_idx=0
    ):
        images = torch.stack(
            [view["img"] for view in views], dim=0
        ).permute(1, 0, 2, 3, 4)  # B S C H W

        # If without batch dimension, add it
        if len(images.shape) == 4:
            images = images.unsqueeze(0)
        if query_points is not None and len(query_points.shape) == 2:
            query_points = query_points.unsqueeze(0)

        if history_info is None:
            history_info = {"token": None}

        aggregated_tokens_list, patch_start_idx = self.aggregator(images)
        predictions = {}

        with torch.cuda.amp.autocast(enabled=False):
            if self.camera_head is not None:
                pose_enc_list = self.camera_head(aggregated_tokens_list)
                predictions["pose_enc"] = pose_enc_list[-1]  # pose encoding of the last iteration

            if self.depth_head is not None:
                depth, depth_conf = self.depth_head(
                    aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
                )
                predictions["depth"] = depth
                predictions["depth_conf"] = depth_conf

            if self.point_head is not None:
                pts3d, pts3d_conf = self.point_head(
                    aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
                )
                predictions["world_points"] = pts3d
                predictions["world_points_conf"] = pts3d_conf

        if self.track_head is not None and query_points is not None:
            track_list, vis, conf = self.track_head(
                aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx, query_points=query_points
            )
            predictions["track"] = track_list[-1]  # track of the last iteration
            predictions["vis"] = vis
            predictions["conf"] = conf
        predictions["images"] = images

        B, S = images.shape[:2]
        ress = []
        for s in range(S):
            res = {
                'pts3d_in_other_view': predictions['world_points'][:, s],  # [B, H, W, 3]
                'conf': predictions['world_points_conf'][:, s],            # [B, H, W]

                'depth': predictions['depth'][:, s],                       # [B, H, W, 1]
                'depth_conf': predictions['depth_conf'][:, s],             # [B, H, W]
                'camera_pose': predictions['pose_enc'][:, s, :],           # [B, 9]

                **({'valid_mask': views[s]["valid_mask"]}
                   if 'valid_mask' in views[s] else {}),                   # [B, H, W]

                **({'track': predictions['track'][:, s],                   # [B, N, 2]
                    'vis': predictions['vis'][:, s],                       # [B, N]
                    'track_conf': predictions['conf'][:, s]}
                   if 'track' in predictions else {})
            }
            ress.append(res)
        return StreamVGGTOutput(ress=ress, views=views)  # [S] [B, C, H, W]

    def inference(self, frames, query_points: torch.Tensor = None, past_key_values=None):
        past_key_values = [None] * self.aggregator.depth
        past_key_values_camera = [None] * self.camera_head.trunk_depth

        all_ress = []
        processed_frames = []

        for i, frame in enumerate(frames):
            images = frame["img"].unsqueeze(0)
            aggregator_output = self.aggregator(
                images,
                past_key_values=past_key_values,
                use_cache=True,
                past_frame_idx=i
            )

            if isinstance(aggregator_output, tuple) and len(aggregator_output) == 3:
                aggregated_tokens, patch_start_idx, past_key_values = aggregator_output
            else:
                aggregated_tokens, patch_start_idx = aggregator_output

            with torch.cuda.amp.autocast(enabled=False):
                if self.camera_head is not None:
                    pose_enc, past_key_values_camera = self.camera_head(aggregated_tokens, past_key_values_camera=past_key_values_camera, use_cache=True)
                    pose_enc = pose_enc[-1]
                    camera_pose = pose_enc[:, 0, :]

                if self.depth_head is not None:
                    depth, depth_conf = self.depth_head(
                        aggregated_tokens, images=images, patch_start_idx=patch_start_idx
                    )
                    depth = depth[:, 0]
                    depth_conf = depth_conf[:, 0]

                if self.point_head is not None:
                    pts3d, pts3d_conf = self.point_head(
                        aggregated_tokens, images=images, patch_start_idx=patch_start_idx
                    )
                    pts3d = pts3d[:, 0]
                    pts3d_conf = pts3d_conf[:, 0]

                if self.track_head is not None and query_points is not None:
                    track_list, vis, conf = self.track_head(
                        aggregated_tokens, images=images, patch_start_idx=patch_start_idx, query_points=query_points
                    )
                    track = track_list[-1][:, 0]
                    query_points = track
                    vis = vis[:, 0]
                    track_conf = conf[:, 0]

            all_ress.append({
                'pts3d_in_other_view': pts3d,
                'conf': pts3d_conf,
                'depth': depth,
                'depth_conf': depth_conf,
                'camera_pose': camera_pose,
                **({'valid_mask': frame["valid_mask"]}
                   if 'valid_mask' in frame else {}),

                **({'track': track,
                    'vis': vis,
                    'track_conf': track_conf}
                   if query_points is not None else {})
            })
            processed_frames.append(frame)

        output = StreamVGGTOutput(ress=all_ress, views=processed_frames)
        return output
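Not part of the diff: a minimal streaming-inference sketch for the class above, assuming the model and its heads can be constructed locally; the checkpoint id mentioned in the comment is a placeholder, not a real repo.

# Sketch: per-frame inference with KV caching; random weights and random frames.
import torch
from streamvggt.models.streamvggt import StreamVGGT

device = "cuda" if torch.cuda.is_available() else "cpu"
model = StreamVGGT()  # or StreamVGGT.from_pretrained("<your-checkpoint>") via PyTorchModelHubMixin
model.eval().to(device)

# Each frame is a dict with an "img" tensor of shape [1, 3, H, W] in [0, 1]
frames = [{"img": torch.rand(1, 3, 518, 518, device=device)} for _ in range(4)]

with torch.no_grad():
    out = model.inference(frames)  # one aggregator pass per frame, cache reused across frames
print(len(out.ress), out.ress[0]["depth"].shape, out.ress[0]["camera_pose"].shape)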
streamvggt/utils/geometry.py
ADDED
@@ -0,0 +1,166 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import torch
import numpy as np


def unproject_depth_map_to_point_map(
    depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray
) -> np.ndarray:
    """
    Unproject a batch of depth maps to 3D world coordinates.

    Args:
        depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W)
        extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4)
        intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3)

    Returns:
        np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3)
    """
    if isinstance(depth_map, torch.Tensor):
        depth_map = depth_map.cpu().numpy()
    if isinstance(extrinsics_cam, torch.Tensor):
        extrinsics_cam = extrinsics_cam.cpu().numpy()
    if isinstance(intrinsics_cam, torch.Tensor):
        intrinsics_cam = intrinsics_cam.cpu().numpy()

    world_points_list = []
    for frame_idx in range(depth_map.shape[0]):
        cur_world_points, _, _ = depth_to_world_coords_points(
            depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx]
        )
        world_points_list.append(cur_world_points)
    world_points_array = np.stack(world_points_list, axis=0)

    return world_points_array


def depth_to_world_coords_points(
    depth_map: np.ndarray,
    extrinsic: np.ndarray,
    intrinsic: np.ndarray,
    eps=1e-8,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Convert a depth map to world coordinates.

    Args:
        depth_map (np.ndarray): Depth map of shape (H, W).
        intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
        extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world.

    Returns:
        tuple[np.ndarray, np.ndarray]: World coordinates (H, W, 3) and valid depth mask (H, W).
    """
    if depth_map is None:
        return None, None, None

    # Valid depth mask
    point_mask = depth_map > eps

    # Convert depth map to camera coordinates
    cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic)

    # Multiply with the inverse of extrinsic matrix to transform to world coordinates
    # extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4))
    cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0]

    R_cam_to_world = cam_to_world_extrinsic[:3, :3]
    t_cam_to_world = cam_to_world_extrinsic[:3, 3]

    # Apply the rotation and translation to the camera coordinates
    world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world  # HxWx3, 3x3 -> HxWx3
    # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world

    return world_coords_points, cam_coords_points, point_mask


def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """
    Convert a depth map to camera coordinates.

    Args:
        depth_map (np.ndarray): Depth map of shape (H, W).
        intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).

    Returns:
        tuple[np.ndarray, np.ndarray]: Camera coordinates (H, W, 3)
    """
    H, W = depth_map.shape
    assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3"
    assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew"

    # Intrinsic parameters
    fu, fv = intrinsic[0, 0], intrinsic[1, 1]
    cu, cv = intrinsic[0, 2], intrinsic[1, 2]

    # Generate grid of pixel coordinates
    u, v = np.meshgrid(np.arange(W), np.arange(H))

    # Unproject to camera coordinates
    x_cam = (u - cu) * depth_map / fu
    y_cam = (v - cv) * depth_map / fv
    z_cam = depth_map

    # Stack to form camera coordinates
    cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)

    return cam_coords


def closed_form_inverse_se3(se3, R=None, T=None):
    """
    Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.

    If `R` and `T` are provided, they must correspond to the rotation and translation
    components of `se3`. Otherwise, they will be extracted from `se3`.

    Args:
        se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
        R (optional): Nx3x3 array or tensor of rotation matrices.
        T (optional): Nx3x1 array or tensor of translation vectors.

    Returns:
        Inverted SE3 matrices with the same type and device as `se3`.

    Shapes:
        se3: (N, 4, 4)
        R: (N, 3, 3)
        T: (N, 3, 1)
    """
    # Check if se3 is a numpy array or a torch tensor
    is_numpy = isinstance(se3, np.ndarray)

    # Validate shapes
    if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
        raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.")

    # Extract R and T if not provided
    if R is None:
        R = se3[:, :3, :3]  # (N,3,3)
    if T is None:
        T = se3[:, :3, 3:]  # (N,3,1)

    # Transpose R
    if is_numpy:
        # Compute the transpose of the rotation for NumPy
        R_transposed = np.transpose(R, (0, 2, 1))
        # -R^T t for NumPy
        top_right = -np.matmul(R_transposed, T)
        inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
    else:
        R_transposed = R.transpose(1, 2)  # (N,3,3)
        top_right = -torch.bmm(R_transposed, T)  # (N,3,1)
        inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
        inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)

    inverted_matrix[:, :3, :3] = R_transposed
    inverted_matrix[:, :3, 3:] = top_right

    return inverted_matrix
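Not part of the diff: a toy unprojection sketch using the helpers above; with an identity extrinsic, camera and world coordinates coincide and the z-channel of the output equals the input depth.

# Sketch: unproject a constant depth map with identity cam-from-world extrinsics.
import numpy as np
from streamvggt.utils.geometry import unproject_depth_map_to_point_map

S, H, W = 2, 4, 6
depth = np.ones((S, H, W, 1), dtype=np.float32) * 2.0   # constant 2.0 depth
extrinsics = np.tile(np.eye(4)[:3], (S, 1, 1))          # (S, 3, 4)
intrinsics = np.tile(np.array([[100.0, 0.0, W / 2],
                               [0.0, 100.0, H / 2],
                               [0.0, 0.0, 1.0]]), (S, 1, 1))  # (S, 3, 3)

world_points = unproject_depth_map_to_point_map(depth, extrinsics, intrinsics)
print(world_points.shape)     # (2, 4, 6, 3)
print(world_points[0, 0, 0])  # z component is 2.0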
streamvggt/utils/load_fn.py
ADDED
@@ -0,0 +1,146 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from PIL import Image
from torchvision import transforms as TF


def load_and_preprocess_images(image_path_list, mode="crop"):
    """
    A quick start function to load and preprocess images for model input.
    This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes.

    Args:
        image_path_list (list): List of paths to image files
        mode (str, optional): Preprocessing mode, either "crop" or "pad".
            - "crop" (default): Sets width to 518px and center crops height if needed.
            - "pad": Preserves all pixels by making the largest dimension 518px
              and padding the smaller dimension to reach a square shape.

    Returns:
        torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W)

    Raises:
        ValueError: If the input list is empty or if mode is invalid

    Notes:
        - Images with different dimensions will be padded with white (value=1.0)
        - A warning is printed when images have different shapes
        - When mode="crop": The function ensures width=518px while maintaining aspect ratio
          and height is center-cropped if larger than 518px
        - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio
          and the smaller dimension is padded to reach a square shape (518x518)
        - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements
    """
    # Check for empty list
    if len(image_path_list) == 0:
        raise ValueError("At least 1 image is required")

    # Validate mode
    if mode not in ["crop", "pad"]:
        raise ValueError("Mode must be either 'crop' or 'pad'")

    images = []
    shapes = set()
    to_tensor = TF.ToTensor()
    target_size = 518

    # First process all images and collect their shapes
    for image_path in image_path_list:

        # Open image
        img = Image.open(image_path)

        # If there's an alpha channel, blend onto white background:
        if img.mode == "RGBA":
            # Create white background
            background = Image.new("RGBA", img.size, (255, 255, 255, 255))
            # Alpha composite onto the white background
            img = Image.alpha_composite(background, img)

        # Now convert to "RGB" (this step assigns white for transparent areas)
        img = img.convert("RGB")

        width, height = img.size

        if mode == "pad":
            # Make the largest dimension 518px while maintaining aspect ratio
            if width >= height:
                new_width = target_size
                new_height = round(height * (new_width / width) / 14) * 14  # Make divisible by 14
            else:
                new_height = target_size
                new_width = round(width * (new_height / height) / 14) * 14  # Make divisible by 14
        else:  # mode == "crop"
            # Original behavior: set width to 518px
            new_width = target_size
            # Calculate height maintaining aspect ratio, divisible by 14
            new_height = round(height * (new_width / width) / 14) * 14

        # Resize with new dimensions (width, height)
        img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
        img = to_tensor(img)  # Convert to tensor (0, 1)

        # Center crop height if it's larger than 518 (only in crop mode)
        if mode == "crop" and new_height > target_size:
            start_y = (new_height - target_size) // 2
            img = img[:, start_y: start_y + target_size, :]

        # For pad mode, pad to make a square of target_size x target_size
        if mode == "pad":
            h_padding = target_size - img.shape[1]
            w_padding = target_size - img.shape[2]

            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left

                # Pad with white (value=1.0)
                img = torch.nn.functional.pad(
                    img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
                )

        shapes.add((img.shape[1], img.shape[2]))
        images.append(img)

    # Check if we have different shapes
    # In theory our model can also work well with different shapes
    if len(shapes) > 1:
        print(f"Warning: Found images with different shapes: {shapes}")
        # Find maximum dimensions
        max_height = max(shape[0] for shape in shapes)
        max_width = max(shape[1] for shape in shapes)

        # Pad images if necessary
        padded_images = []
        for img in images:
            h_padding = max_height - img.shape[1]
            w_padding = max_width - img.shape[2]

            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left

                img = torch.nn.functional.pad(
                    img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
                )
            padded_images.append(img)
        images = padded_images

    images = torch.stack(images)  # concatenate images

    # Ensure correct shape when single image
    if len(image_path_list) == 1:
        # Verify shape is (1, C, H, W)
        if images.dim() == 3:
            images = images.unsqueeze(0)

    return images
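Not part of the diff: a short usage sketch of the loader above; the glob pattern is a placeholder path, not a file shipped with the Space.

# Sketch: load a folder of frames into a batched tensor for the model.
import glob
from streamvggt.utils.load_fn import load_and_preprocess_images

image_paths = sorted(glob.glob("examples/room/*.png"))  # hypothetical directory
images = load_and_preprocess_images(image_paths, mode="pad")
print(images.shape)  # (N, 3, 518, 518) in pad mode; width is fixed to 518 in crop mode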
vggt/utils/__pycache__/pose_enc.cpython-310.pyc → streamvggt/utils/pose_enc.py
RENAMED
Binary files a/vggt/utils/__pycache__/pose_enc.cpython-310.pyc and b/streamvggt/utils/pose_enc.py differ
streamvggt/utils/rotation.py
ADDED
@@ -0,0 +1,138 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d

import torch
import numpy as np
import torch.nn.functional as F


def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Quaternion Order: XYZW or say ijkr, scalar-last

    Convert rotations given as quaternions to rotation matrices.
    Args:
        quaternions: quaternions with real part last,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    i, j, k, r = torch.unbind(quaternions, -1)
    # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
    two_s = 2.0 / (quaternions * quaternions).sum(-1)

    o = torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        -1,
    )
    return o.reshape(quaternions.shape[:-1] + (3, 3))


def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to quaternions.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part last, as tensor of shape (..., 4).
        Quaternion Order: XYZW or say ijkr, scalar-last
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)

    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )

    # we produce the desired quaternion multiplied by each of r, i, j, k
    quat_by_rijk = torch.stack(
        [
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            #  `int`.
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            #  `int`.
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            #  `int`.
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            #  `int`.
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )

    # We floor here at 0.1 but the exact level is not important; if q_abs is small,
    # the candidate won't be picked.
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))

    # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
    # forall i; we pick the best-conditioned one (with the largest denominator)
    out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,))

    # Convert from rijk to ijkr
    out = out[..., [1, 2, 3, 0]]

    out = standardize_quaternion(out)

    return out


def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
    """
    Returns torch.sqrt(torch.max(0, x))
    but with a zero subgradient where x is 0.
    """
    ret = torch.zeros_like(x)
    positive_mask = x > 0
    if torch.is_grad_enabled():
        ret[positive_mask] = torch.sqrt(x[positive_mask])
    else:
        ret = torch.where(positive_mask, torch.sqrt(x), ret)
    return ret


def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert a unit quaternion to a standard form: one in which the real
    part is non negative.

    Args:
        quaternions: Quaternions with real part last,
            as tensor of shape (..., 4).

    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
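Not part of the diff: a round-trip sanity sketch for the scalar-last (XYZW) quaternion helpers above, using random unit quaternions.

# Sketch: quaternion -> rotation matrix -> quaternion should be the identity up to numerics.
import torch
from streamvggt.utils.rotation import quat_to_mat, mat_to_quat, standardize_quaternion

q = torch.nn.functional.normalize(torch.randn(5, 4), dim=-1)  # random unit quaternions
q = standardize_quaternion(q)                                 # force non-negative real part
R = quat_to_mat(q)                                            # (5, 3, 3) rotation matrices
q_back = mat_to_quat(R)                                       # (5, 4), also standardized

print(torch.allclose(q, q_back, atol=1e-5))  # True up to numerics
print(torch.allclose(R @ R.transpose(-1, -2), torch.eye(3).expand(5, 3, 3), atol=1e-5))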
streamvggt/utils/visual_track.py
ADDED
@@ -0,0 +1,239 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import cv2
import torch
import numpy as np
import os


def color_from_xy(x, y, W, H, cmap_name="hsv"):
    """
    Map (x, y) -> color in (R, G, B).
    1) Normalize x,y to [0,1].
    2) Combine them into a single scalar c in [0,1].
    3) Use matplotlib's colormap to convert c -> (R,G,B).

    You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y).
    """
    import matplotlib.cm
    import matplotlib.colors

    x_norm = x / max(W - 1, 1)
    y_norm = y / max(H - 1, 1)
    # Simple combination:
    c = (x_norm + y_norm) / 2.0

    cmap = matplotlib.cm.get_cmap(cmap_name)
    # cmap(c) -> (r,g,b,a) in [0,1]
    rgba = cmap(c)
    r, g, b = rgba[0], rgba[1], rgba[2]
    return (r, g, b)  # in [0,1], RGB order


def get_track_colors_by_position(tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv"):
    """
    Given all tracks in one sample (b), compute a (N,3) array of RGB color values
    in [0,255]. The color is determined by the (x,y) position in the first
    visible frame for each track.

    Args:
        tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame.
        vis_mask_b: (S, N) boolean mask; if None, assume all are visible.
        image_width, image_height: used for normalizing (x, y).
        cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet').

    Returns:
        track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255].
    """
    S, N, _ = tracks_b.shape
    track_colors = np.zeros((N, 3), dtype=np.uint8)

    if vis_mask_b is None:
        # treat all as visible
        vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device)

    for i in range(N):
        # Find first visible frame for track i
        visible_frames = torch.where(vis_mask_b[:, i])[0]
        if len(visible_frames) == 0:
            # track is never visible; just assign black or something
            track_colors[i] = (0, 0, 0)
            continue

        first_s = int(visible_frames[0].item())
        # use that frame's (x,y)
        x, y = tracks_b[first_s, i].tolist()

        # map (x,y) -> (R,G,B) in [0,1]
        r, g, b = color_from_xy(x, y, W=image_width, H=image_height, cmap_name=cmap_name)
        # scale to [0,255]
        r, g, b = int(r * 255), int(g * 255), int(b * 255)
        track_colors[i] = (r, g, b)

    return track_colors


def visualize_tracks_on_images(
    images,
    tracks,
    track_vis_mask=None,
    out_dir="track_visuals_concat_by_xy",
    image_format="CHW",  # "CHW" or "HWC"
    normalize_mode="[0,1]",
    cmap_name="hsv",  # e.g. "hsv", "rainbow", "jet"
    frames_per_row=4,  # New parameter for grid layout
    save_grid=True,  # Flag to control whether to save the grid image
):
    """
    Visualizes frames in a grid layout with specified frames per row.
    Each track's color is determined by its (x,y) position
    in the first visible frame (or frame 0 if always visible).
    Finally convert the BGR result to RGB before saving.
    Also saves each individual frame as a separate PNG file.

    Args:
        images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC.
        tracks: torch.Tensor (S, N, 2), last dim = (x, y).
        track_vis_mask: torch.Tensor (S, N) or None.
        out_dir: folder to save visualizations.
        image_format: "CHW" or "HWC".
        normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255
        cmap_name: a matplotlib colormap name for color_from_xy.
        frames_per_row: number of frames to display in each row of the grid.
        save_grid: whether to save all frames in one grid image.

    Returns:
        None (saves images in out_dir).
    """

    if len(tracks.shape) == 4:
        tracks = tracks.squeeze(0)
        images = images.squeeze(0)
        if track_vis_mask is not None:
            track_vis_mask = track_vis_mask.squeeze(0)

    import matplotlib

    matplotlib.use("Agg")  # for non-interactive (optional)

    os.makedirs(out_dir, exist_ok=True)

    S = images.shape[0]
    _, N, _ = tracks.shape  # (S, N, 2)

    # Move to CPU
    images = images.cpu().clone()
    tracks = tracks.cpu().clone()
    if track_vis_mask is not None:
        track_vis_mask = track_vis_mask.cpu().clone()

    # Infer H, W from images shape
    if image_format == "CHW":
        # e.g. images[s].shape = (3, H, W)
        H, W = images.shape[2], images.shape[3]
    else:
        # e.g. images[s].shape = (H, W, 3)
        H, W = images.shape[1], images.shape[2]

    # Pre-compute the color for each track i based on first visible position
    track_colors_rgb = get_track_colors_by_position(
        tracks,  # shape (S, N, 2)
        vis_mask_b=track_vis_mask if track_vis_mask is not None else None,
        image_width=W,
        image_height=H,
        cmap_name=cmap_name,
    )

    # We'll accumulate each frame's drawn image in a list
    frame_images = []

    for s in range(S):
        # shape => either (3, H, W) or (H, W, 3)
        img = images[s]

        # Convert to (H, W, 3)
        if image_format == "CHW":
            img = img.permute(1, 2, 0)  # (H, W, 3)
        # else "HWC", do nothing

        img = img.numpy().astype(np.float32)

        # Scale to [0,255] if needed
        if normalize_mode == "[0,1]":
            img = np.clip(img, 0, 1) * 255.0
        elif normalize_mode == "[-1,1]":
            img = (img + 1.0) * 0.5 * 255.0
            img = np.clip(img, 0, 255.0)
        # else no normalization

        # Convert to uint8
        img = img.astype(np.uint8)

        # For drawing in OpenCV, convert to BGR
        img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

        # Draw each visible track
        cur_tracks = tracks[s]  # shape (N, 2)
        if track_vis_mask is not None:
            valid_indices = torch.where(track_vis_mask[s])[0]
        else:
            valid_indices = range(N)

        cur_tracks_np = cur_tracks.numpy()
        for i in valid_indices:
            x, y = cur_tracks_np[i]
            pt = (int(round(x)), int(round(y)))

            # track_colors_rgb[i] is (R,G,B). For OpenCV circle, we need BGR
            R, G, B = track_colors_rgb[i]
            color_bgr = (int(B), int(G), int(R))
            cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1)

        # Convert back to RGB for consistent final saving:
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        # Save individual frame
        frame_path = os.path.join(out_dir, f"frame_{s:04d}.png")
        # Convert to BGR for OpenCV imwrite
        frame_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
        cv2.imwrite(frame_path, frame_bgr)

        frame_images.append(img_rgb)

    # Only create and save the grid image if save_grid is True
    if save_grid:
        # Calculate grid dimensions
        num_rows = (S + frames_per_row - 1) // frames_per_row  # Ceiling division

        # Create a grid of images
        grid_img = None
        for row in range(num_rows):
            start_idx = row * frames_per_row
            end_idx = min(start_idx + frames_per_row, S)

            # Concatenate this row horizontally
            row_img = np.concatenate(frame_images[start_idx:end_idx], axis=1)

            # If this row has fewer than frames_per_row images, pad with black
            if end_idx - start_idx < frames_per_row:
                padding_width = (frames_per_row - (end_idx - start_idx)) * W
                padding = np.zeros((H, padding_width, 3), dtype=np.uint8)
                row_img = np.concatenate([row_img, padding], axis=1)

            # Add this row to the grid
            if grid_img is None:
                grid_img = row_img
            else:
                grid_img = np.concatenate([grid_img, row_img], axis=0)

        out_path = os.path.join(out_dir, "tracks_grid.png")
        # Convert back to BGR for OpenCV imwrite
        grid_img_bgr = cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR)
        cv2.imwrite(out_path, grid_img_bgr)
        print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}")

    print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png")
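Not part of the diff: a small sketch that feeds random frames and static random tracks to the visualizer above; the output directory is an arbitrary choice.

# Sketch: render random tracks onto random CHW frames and write PNGs.
import torch
from streamvggt.utils.visual_track import visualize_tracks_on_images

S, N, H, W = 4, 16, 128, 160
images = torch.rand(S, 3, H, W)  # CHW frames in [0, 1]
tracks = torch.stack(
    [torch.rand(N) * (W - 1), torch.rand(N) * (H - 1)], dim=-1
).unsqueeze(0).repeat(S, 1, 1)   # (S, N, 2), same points in every frame
vis_mask = torch.ones(S, N, dtype=torch.bool)

visualize_tracks_on_images(images, tracks, vis_mask, out_dir="/tmp/track_vis", frames_per_row=2)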
vggt/heads/__pycache__/camera_head.cpython-310.pyc
DELETED
Binary file (4.27 kB)
vggt/heads/__pycache__/camera_head.cpython-311.pyc
DELETED
Binary file (6.8 kB)
vggt/heads/__pycache__/camera_head.cpython-312.pyc
DELETED
Binary file (6.13 kB)
vggt/heads/__pycache__/dpt_head.cpython-310.pyc
DELETED
Binary file (12.6 kB)
vggt/heads/__pycache__/dpt_head.cpython-311.pyc
DELETED
Binary file (21.7 kB)
vggt/heads/__pycache__/dpt_head.cpython-312.pyc
DELETED
Binary file (20.3 kB)
vggt/heads/__pycache__/head_act.cpython-311.pyc
DELETED
Binary file (4.87 kB)