lch01 committed on
Commit 28c1b3e · 1 Parent(s): a77de93

update to the published ver

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. app.py +2 -2
  2. streamvggt/heads/camera_head.py +175 -0
  3. streamvggt/heads/dpt_head.py +472 -0
  4. vggt/heads/__pycache__/head_act.cpython-310.pyc → streamvggt/heads/head_act.py +0 -0
  5. vggt/heads/__pycache__/track_head.cpython-310.pyc → streamvggt/heads/track_head.py +0 -0
  6. streamvggt/heads/track_modules/__init__.py +0 -0
  7. streamvggt/heads/track_modules/__pycache__/__init__.cpython-310.pyc +0 -0
  8. streamvggt/heads/track_modules/__pycache__/__init__.cpython-311.pyc +0 -0
  9. streamvggt/heads/track_modules/__pycache__/__init__.cpython-312.pyc +0 -0
  10. streamvggt/heads/track_modules/__pycache__/base_track_predictor.cpython-310.pyc +0 -0
  11. streamvggt/heads/track_modules/__pycache__/base_track_predictor.cpython-311.pyc +0 -0
  12. streamvggt/heads/track_modules/__pycache__/base_track_predictor.cpython-312.pyc +0 -0
  13. streamvggt/heads/track_modules/__pycache__/blocks.cpython-310.pyc +0 -0
  14. streamvggt/heads/track_modules/__pycache__/blocks.cpython-311.pyc +0 -0
  15. streamvggt/heads/track_modules/__pycache__/blocks.cpython-312.pyc +0 -0
  16. streamvggt/heads/track_modules/__pycache__/modules.cpython-310.pyc +0 -0
  17. streamvggt/heads/track_modules/__pycache__/modules.cpython-311.pyc +0 -0
  18. streamvggt/heads/track_modules/__pycache__/modules.cpython-312.pyc +0 -0
  19. streamvggt/heads/track_modules/__pycache__/utils.cpython-310.pyc +0 -0
  20. streamvggt/heads/track_modules/__pycache__/utils.cpython-311.pyc +0 -0
  21. streamvggt/heads/track_modules/__pycache__/utils.cpython-312.pyc +0 -0
  22. streamvggt/heads/track_modules/base_track_predictor.py +195 -0
  23. streamvggt/heads/track_modules/blocks.py +237 -0
  24. streamvggt/heads/track_modules/modules.py +211 -0
  25. streamvggt/heads/track_modules/utils.py +216 -0
  26. vggt/heads/__pycache__/utils.cpython-310.pyc → streamvggt/heads/utils.py +0 -0
  27. streamvggt/layers/__init__.py +5 -0
  28. streamvggt/layers/attention.py +129 -0
  29. streamvggt/layers/block.py +263 -0
  30. streamvggt/layers/drop_path.py +24 -0
  31. streamvggt/layers/layer_scale.py +20 -0
  32. streamvggt/layers/mlp.py +30 -0
  33. streamvggt/layers/patch_embed.py +79 -0
  34. vggt/layers/__pycache__/rope.cpython-310.pyc → streamvggt/layers/rope.py +0 -0
  35. streamvggt/layers/swiglu_ffn.py +67 -0
  36. streamvggt/layers/vision_transformer.py +398 -0
  37. streamvggt/models/aggregator.py +394 -0
  38. streamvggt/models/streamvggt.py +172 -0
  39. streamvggt/utils/geometry.py +166 -0
  40. streamvggt/utils/load_fn.py +146 -0
  41. vggt/utils/__pycache__/pose_enc.cpython-310.pyc → streamvggt/utils/pose_enc.py +0 -0
  42. streamvggt/utils/rotation.py +138 -0
  43. streamvggt/utils/visual_track.py +239 -0
  44. vggt/heads/__pycache__/camera_head.cpython-310.pyc +0 -0
  45. vggt/heads/__pycache__/camera_head.cpython-311.pyc +0 -0
  46. vggt/heads/__pycache__/camera_head.cpython-312.pyc +0 -0
  47. vggt/heads/__pycache__/dpt_head.cpython-310.pyc +0 -0
  48. vggt/heads/__pycache__/dpt_head.cpython-311.pyc +0 -0
  49. vggt/heads/__pycache__/dpt_head.cpython-312.pyc +0 -0
  50. vggt/heads/__pycache__/head_act.cpython-311.pyc +0 -0
app.py CHANGED
@@ -29,7 +29,7 @@ import gc
 import time
 
 from visual_util import predictions_to_glb
-from vggt.models.vggt import VGGT
+from streamvggt.models.streamvggt import StreamVGGT
 from vggt.utils.load_fn import load_and_preprocess_images
 from vggt.utils.pose_enc import pose_encoding_to_extri_intri
 from vggt.utils.geometry import unproject_depth_map_to_point_map
@@ -45,7 +45,7 @@ path = hf_hub_download(
 revision="main",
 force_download=True
 )
-model = VGGT(use_causal_global=True, use_distil=True)
+model = StreamVGGT()
 ckpt = torch.load(path, map_location=device)
 model.load_state_dict(ckpt, strict=True)
 model = model.to(device)
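
For reference, a minimal sketch of the updated loading path after this change. The repo id and checkpoint filename are placeholders (they sit outside the visible hunk), and StreamVGGT is assumed to take the no-argument constructor shown in the diff.

import torch
from huggingface_hub import hf_hub_download
from streamvggt.models.streamvggt import StreamVGGT

device = "cuda" if torch.cuda.is_available() else "cpu"

# Placeholder repo id / filename: the real values are not shown in the hunk above.
path = hf_hub_download(repo_id="<model-repo>", filename="<checkpoint.pt>", revision="main")

model = StreamVGGT()
ckpt = torch.load(path, map_location=device)
model.load_state_dict(ckpt, strict=True)
model = model.to(device).eval()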
streamvggt/heads/camera_head.py ADDED
@@ -0,0 +1,175 @@
1
+ import math
2
+ import numpy as np
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from streamvggt.layers import Mlp
9
+ from streamvggt.layers.block import Block
10
+ from streamvggt.heads.head_act import activate_pose
11
+
12
+
13
+ class CameraHead(nn.Module):
14
+ def __init__(
15
+ self,
16
+ dim_in: int = 2048,
17
+ trunk_depth: int = 4,
18
+ pose_encoding_type: str = "absT_quaR_FoV",
19
+ num_heads: int = 16,
20
+ mlp_ratio: int = 4,
21
+ init_values: float = 0.01,
22
+ trans_act: str = "linear",
23
+ quat_act: str = "linear",
24
+ fl_act: str = "relu", # Field-of-view activation: ensures FOV values are positive.
25
+ ):
26
+ super().__init__()
27
+
28
+ if pose_encoding_type == "absT_quaR_FoV":
29
+ self.target_dim = 9
30
+ else:
31
+ raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
32
+
33
+ self.trans_act = trans_act
34
+ self.quat_act = quat_act
35
+ self.fl_act = fl_act
36
+ self.trunk_depth = trunk_depth
37
+
38
+ # Build the trunk using a sequence of transformer blocks.
39
+ self.trunk = nn.Sequential(
40
+ *[
41
+ Block(
42
+ dim=dim_in,
43
+ num_heads=num_heads,
44
+ mlp_ratio=mlp_ratio,
45
+ init_values=init_values,
46
+ )
47
+ for _ in range(trunk_depth)
48
+ ]
49
+ )
50
+
51
+ # Normalizations for camera token and trunk output.
52
+ self.token_norm = nn.LayerNorm(dim_in)
53
+ self.trunk_norm = nn.LayerNorm(dim_in)
54
+
55
+ # Learnable empty camera pose token.
56
+ self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
57
+ self.embed_pose = nn.Linear(self.target_dim, dim_in)
58
+
59
+ # Module for producing modulation parameters: shift, scale, and a gate.
60
+ self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
61
+
62
+ # Adaptive layer normalization without affine parameters.
63
+ self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
64
+ self.pose_branch = Mlp(
65
+ in_features=dim_in,
66
+ hidden_features=dim_in // 2,
67
+ out_features=self.target_dim,
68
+ drop=0,
69
+ )
70
+
71
+ def forward(self, aggregated_tokens_list: list, num_iterations: int = 4, past_key_values_camera = None, use_cache: bool = False) -> list:
72
+ """
73
+ Forward pass to predict camera parameters.
74
+
75
+ Args:
76
+ aggregated_tokens_list (list): List of token tensors from the network;
77
+ the last tensor is used for prediction.
78
+ num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
79
+
80
+ Returns:
81
+ list: A list of predicted camera encodings (post-activation) from each iteration.
82
+ """
83
+ # Use tokens from the last block for camera prediction.
84
+ tokens = aggregated_tokens_list[-1]
85
+
86
+ # Extract the camera tokens
87
+ pose_tokens = tokens[:, :, 0]
88
+ pose_tokens = self.token_norm(pose_tokens)
89
+
90
+ if use_cache:
91
+ pred_pose_enc_list, past_key_values_camera = self.trunk_fn(pose_tokens, num_iterations, past_key_values_camera, use_cache)
92
+ return pred_pose_enc_list, past_key_values_camera
93
+ else:
94
+ pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations, past_key_values_camera=None, use_cache=use_cache)
95
+ return pred_pose_enc_list
96
+
97
+ def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int, past_key_values_camera, use_cache: bool) -> list:
98
+ """
99
+ Iteratively refine camera pose predictions.
100
+
101
+ Args:
102
+ pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
103
+ num_iterations (int): Number of refinement iterations.
104
+
105
+ Returns:
106
+ list: List of activated camera encodings from each iteration.
107
+ """
108
+ B, S, C = pose_tokens.shape # S is expected to be 1.
109
+ pred_pose_enc = None
110
+ pred_pose_enc_list = []
111
+
112
+ for _ in range(num_iterations):
113
+ # Use a learned empty pose for the first iteration.
114
+ if pred_pose_enc is None:
115
+ module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
116
+ else:
117
+ # Detach the previous prediction to avoid backprop through time.
118
+ pred_pose_enc = pred_pose_enc.detach()
119
+ module_input = self.embed_pose(pred_pose_enc)
120
+
121
+ # Generate modulation parameters and split them into shift, scale, and gate components.
122
+ shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
123
+
124
+ # Adaptive layer normalization and modulation.
125
+ pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
126
+ pose_tokens_modulated = pose_tokens_modulated + pose_tokens
127
+
128
+ if not use_cache:
129
+ L = S * 1
130
+ frame_ids = torch.arange(L, device=pose_tokens_modulated.device) // 1 # [0, 1, ..., S-1]
131
+ future_frame = frame_ids.unsqueeze(1) < frame_ids.unsqueeze(0)
132
+ attn_mask = future_frame.to(pose_tokens_modulated.dtype) * torch.finfo(pose_tokens_modulated.dtype).min
133
+ else:
134
+ attn_mask = None
135
+
136
+ if use_cache:
137
+ for idx in range(self.trunk_depth):
138
+ pose_tokens_modulated, block_kv = self.trunk[idx](
139
+ pose_tokens_modulated,
140
+ attn_mask=attn_mask,
141
+ past_key_values=past_key_values_camera[idx] if past_key_values_camera[idx] is not None else None,
142
+ use_cache=True
143
+ )
144
+ past_key_values_camera[idx] = block_kv
145
+ else:
146
+ for idx in range(self.trunk_depth):
147
+ pose_tokens_modulated = self.trunk[idx](pose_tokens_modulated, attn_mask=attn_mask)
148
+
149
+ # Compute the delta update for the pose encoding.
150
+ pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
151
+
152
+ if pred_pose_enc is None:
153
+ pred_pose_enc = pred_pose_enc_delta
154
+ else:
155
+ pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
156
+
157
+ # Apply final activation functions for translation, quaternion, and field-of-view.
158
+ activated_pose = activate_pose(
159
+ pred_pose_enc,
160
+ trans_act=self.trans_act,
161
+ quat_act=self.quat_act,
162
+ fl_act=self.fl_act,
163
+ )
164
+ pred_pose_enc_list.append(activated_pose)
165
+
166
+ if use_cache:
167
+ return pred_pose_enc_list, past_key_values_camera
168
+ return pred_pose_enc_list
169
+
170
+
171
+ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
172
+ """
173
+ Modulate the input tensor using scaling and shifting parameters.
174
+ """
175
+ return x * (1 + scale) + shift
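
For context, a self-contained sketch of the shift/scale/gate modulation pattern that trunk_fn applies around the adaptive LayerNorm; shapes are illustrative and all tensors are random.

import torch
import torch.nn as nn

B, S, C = 2, 1, 2048                      # batch, one camera token per sequence, feature dim
pose_tokens = torch.randn(B, S, C)
module_input = torch.randn(B, S, C)       # stands in for embed_pose(previous pose encoding)

adaln_norm = nn.LayerNorm(C, elementwise_affine=False, eps=1e-6)
modulation = nn.Sequential(nn.SiLU(), nn.Linear(C, 3 * C, bias=True))

shift, scale, gate = modulation(module_input).chunk(3, dim=-1)
modulated = gate * (adaln_norm(pose_tokens) * (1 + scale) + shift)  # modulate(...)
out = modulated + pose_tokens                                       # residual, as in trunk_fn
print(out.shape)  # torch.Size([2, 1, 2048])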
streamvggt/heads/dpt_head.py ADDED
@@ -0,0 +1,472 @@
1
+ import os
2
+ from typing import List, Dict, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from .head_act import activate_head
8
+ from .utils import create_uv_grid, position_grid_to_embed
9
+
10
+
11
+ class DPTHead(nn.Module):
12
+ """
13
+ Args:
14
+ dim_in (int): Input dimension (channels).
15
+ patch_size (int, optional): Patch size. Default is 14.
16
+ output_dim (int, optional): Number of output channels. Default is 4.
17
+ activation (str, optional): Activation type. Default is "inv_log".
18
+ conf_activation (str, optional): Confidence activation type. Default is "expp1".
19
+ features (int, optional): Feature channels for intermediate representations. Default is 256.
20
+ out_channels (List[int], optional): Output channels for each intermediate layer.
21
+ intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
22
+ pos_embed (bool, optional): Whether to use positional embedding. Default is True.
23
+ feature_only (bool, optional): If True, return feature maps only, skipping the final output convolutions and activation head. Default is False.
24
+ down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ dim_in: int,
30
+ patch_size: int = 14,
31
+ output_dim: int = 4,
32
+ activation: str = "inv_log",
33
+ conf_activation: str = "expp1",
34
+ features: int = 256,
35
+ out_channels: List[int] = [256, 512, 1024, 1024],
36
+ intermediate_layer_idx: List[int] = [4, 11, 17, 23],
37
+ pos_embed: bool = True,
38
+ feature_only: bool = False,
39
+ down_ratio: int = 1,
40
+ ) -> None:
41
+ super(DPTHead, self).__init__()
42
+ self.patch_size = patch_size
43
+ self.activation = activation
44
+ self.conf_activation = conf_activation
45
+ self.pos_embed = pos_embed
46
+ self.feature_only = feature_only
47
+ self.down_ratio = down_ratio
48
+ self.intermediate_layer_idx = intermediate_layer_idx
49
+
50
+ self.norm = nn.LayerNorm(dim_in)
51
+
52
+ # Projection layers for each output channel from tokens.
53
+ self.projects = nn.ModuleList(
54
+ [
55
+ nn.Conv2d(
56
+ in_channels=dim_in,
57
+ out_channels=oc,
58
+ kernel_size=1,
59
+ stride=1,
60
+ padding=0,
61
+ )
62
+ for oc in out_channels
63
+ ]
64
+ )
65
+
66
+ # Resize layers for upsampling feature maps.
67
+ self.resize_layers = nn.ModuleList(
68
+ [
69
+ nn.ConvTranspose2d(
70
+ in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
71
+ ),
72
+ nn.ConvTranspose2d(
73
+ in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
74
+ ),
75
+ nn.Identity(),
76
+ nn.Conv2d(
77
+ in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
78
+ ),
79
+ ]
80
+ )
81
+
82
+ self.scratch = _make_scratch(
83
+ out_channels,
84
+ features,
85
+ expand=False,
86
+ )
87
+
88
+ # Attach additional modules to scratch.
89
+ self.scratch.stem_transpose = None
90
+ self.scratch.refinenet1 = _make_fusion_block(features)
91
+ self.scratch.refinenet2 = _make_fusion_block(features)
92
+ self.scratch.refinenet3 = _make_fusion_block(features)
93
+ self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
94
+
95
+ head_features_1 = features
96
+ head_features_2 = 32
97
+
98
+ if feature_only:
99
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
100
+ else:
101
+ self.scratch.output_conv1 = nn.Conv2d(
102
+ head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
103
+ )
104
+ conv2_in_channels = head_features_1 // 2
105
+
106
+ self.scratch.output_conv2 = nn.Sequential(
107
+ nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
108
+ nn.ReLU(inplace=True),
109
+ nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
110
+ )
111
+
112
+ def forward(
113
+ self,
114
+ aggregated_tokens_list: List[torch.Tensor],
115
+ images: torch.Tensor,
116
+ patch_start_idx: int,
117
+ frames_chunk_size: int = 8,
118
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
119
+ """
120
+ Forward pass through the DPT head; supports processing frames in chunks.
121
+ Args:
122
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
123
+ images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
124
+ patch_start_idx (int): Starting index for patch tokens in the token sequence.
125
+ Used to separate patch tokens from other tokens (e.g., camera or register tokens).
126
+ frames_chunk_size (int, optional): Number of frames to process in each chunk.
127
+ If None or larger than S, all frames are processed at once. Default: 8.
128
+
129
+ Returns:
130
+ Tensor or Tuple[Tensor, Tensor]:
131
+ - If feature_only=True: Feature maps with shape [B, S, C, H, W]
132
+ - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
133
+ """
134
+ B, S, _, H, W = images.shape
135
+
136
+ # If frames_chunk_size is not specified or greater than S, process all frames at once
137
+ if frames_chunk_size is None or frames_chunk_size >= S:
138
+ return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)
139
+
140
+ # Otherwise, process frames in chunks to manage memory usage
141
+ assert frames_chunk_size > 0
142
+
143
+ # Process frames in batches
144
+ all_preds = []
145
+ all_conf = []
146
+
147
+ for frames_start_idx in range(0, S, frames_chunk_size):
148
+ frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
149
+
150
+ # Process batch of frames
151
+ if self.feature_only:
152
+ chunk_output = self._forward_impl(
153
+ aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
154
+ )
155
+ all_preds.append(chunk_output)
156
+ else:
157
+ chunk_preds, chunk_conf = self._forward_impl(
158
+ aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
159
+ )
160
+ all_preds.append(chunk_preds)
161
+ all_conf.append(chunk_conf)
162
+
163
+ # Concatenate results along the sequence dimension
164
+ if self.feature_only:
165
+ return torch.cat(all_preds, dim=1)
166
+ else:
167
+ return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)
168
+
169
+ def _forward_impl(
170
+ self,
171
+ aggregated_tokens_list: List[torch.Tensor],
172
+ images: torch.Tensor,
173
+ patch_start_idx: int,
174
+ frames_start_idx: int = None,
175
+ frames_end_idx: int = None,
176
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
177
+ """
178
+ Args:
179
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
180
+ images (Tensor): Input images with shape [B, S, 3, H, W].
181
+ patch_start_idx (int): Starting index for patch tokens.
182
+ frames_start_idx (int, optional): Starting index for frames to process.
183
+ frames_end_idx (int, optional): Ending index for frames to process.
184
+
185
+ Returns:
186
+ Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
187
+ """
188
+ if frames_start_idx is not None and frames_end_idx is not None:
189
+ images = images[:, frames_start_idx:frames_end_idx].contiguous()
190
+
191
+ B, S, _, H, W = images.shape
192
+
193
+ patch_h, patch_w = H // self.patch_size, W // self.patch_size
194
+
195
+ out = []
196
+ dpt_idx = 0
197
+
198
+ for layer_idx in self.intermediate_layer_idx:
199
+ x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
200
+
201
+ # Select frames if processing a chunk
202
+ if frames_start_idx is not None and frames_end_idx is not None:
203
+ x = x[:, frames_start_idx:frames_end_idx]
204
+
205
+ x = x.reshape(B * S, -1, x.shape[-1])
206
+
207
+ x = self.norm(x)
208
+
209
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
210
+
211
+ x = self.projects[dpt_idx](x)
212
+ if self.pos_embed:
213
+ x = self._apply_pos_embed(x, W, H)
214
+ x = self.resize_layers[dpt_idx](x)
215
+
216
+ out.append(x)
217
+ dpt_idx += 1
218
+
219
+ # Fuse features from multiple layers.
220
+ out = self.scratch_forward(out)
221
+ # Interpolate fused output to match target image resolution.
222
+ out = custom_interpolate(
223
+ out,
224
+ (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
225
+ mode="bilinear",
226
+ align_corners=True,
227
+ )
228
+
229
+ if self.pos_embed:
230
+ out = self._apply_pos_embed(out, W, H)
231
+
232
+ if self.feature_only:
233
+ return out.reshape(B, S, *out.shape[1:])
234
+
235
+ out = self.scratch.output_conv2(out)
236
+ preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)
237
+
238
+ preds = preds.reshape(B, S, *preds.shape[1:])
239
+ conf = conf.reshape(B, S, *conf.shape[1:])
240
+ return preds, conf
241
+
242
+ def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
243
+ """
244
+ Apply positional embedding to tensor x.
245
+ """
246
+ patch_w = x.shape[-1]
247
+ patch_h = x.shape[-2]
248
+ pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
249
+ pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
250
+ pos_embed = pos_embed * ratio
251
+ pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
252
+ return x + pos_embed
253
+
254
+ def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
255
+ """
256
+ Forward pass through the fusion blocks.
257
+
258
+ Args:
259
+ features (List[Tensor]): List of feature maps from different layers.
260
+
261
+ Returns:
262
+ Tensor: Fused feature map.
263
+ """
264
+ layer_1, layer_2, layer_3, layer_4 = features
265
+
266
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
267
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
268
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
269
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
270
+
271
+ out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
272
+ del layer_4_rn, layer_4
273
+
274
+ out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
275
+ del layer_3_rn, layer_3
276
+
277
+ out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
278
+ del layer_2_rn, layer_2
279
+
280
+ out = self.scratch.refinenet1(out, layer_1_rn)
281
+ del layer_1_rn, layer_1
282
+
283
+ out = self.scratch.output_conv1(out)
284
+ return out
285
+
286
+
287
+ def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
288
+ return FeatureFusionBlock(
289
+ features,
290
+ nn.ReLU(inplace=True),
291
+ deconv=False,
292
+ bn=False,
293
+ expand=False,
294
+ align_corners=True,
295
+ size=size,
296
+ has_residual=has_residual,
297
+ groups=groups,
298
+ )
299
+
300
+
301
+ def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
302
+ scratch = nn.Module()
303
+ out_shape1 = out_shape
304
+ out_shape2 = out_shape
305
+ out_shape3 = out_shape
306
+ if len(in_shape) >= 4:
307
+ out_shape4 = out_shape
308
+
309
+ if expand:
310
+ out_shape1 = out_shape
311
+ out_shape2 = out_shape * 2
312
+ out_shape3 = out_shape * 4
313
+ if len(in_shape) >= 4:
314
+ out_shape4 = out_shape * 8
315
+
316
+ scratch.layer1_rn = nn.Conv2d(
317
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
318
+ )
319
+ scratch.layer2_rn = nn.Conv2d(
320
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
321
+ )
322
+ scratch.layer3_rn = nn.Conv2d(
323
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
324
+ )
325
+ if len(in_shape) >= 4:
326
+ scratch.layer4_rn = nn.Conv2d(
327
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
328
+ )
329
+ return scratch
330
+
331
+
332
+ class ResidualConvUnit(nn.Module):
333
+ """Residual convolution module."""
334
+
335
+ def __init__(self, features, activation, bn, groups=1):
336
+ """Init.
337
+
338
+ Args:
339
+ features (int): number of features
340
+ """
341
+ super().__init__()
342
+
343
+ self.bn = bn
344
+ self.groups = groups
345
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
346
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
347
+
348
+ self.norm1 = None
349
+ self.norm2 = None
350
+
351
+ self.activation = activation
352
+ self.skip_add = nn.quantized.FloatFunctional()
353
+
354
+ def forward(self, x):
355
+ """Forward pass.
356
+
357
+ Args:
358
+ x (tensor): input
359
+
360
+ Returns:
361
+ tensor: output
362
+ """
363
+
364
+ out = self.activation(x)
365
+ out = self.conv1(out)
366
+ if self.norm1 is not None:
367
+ out = self.norm1(out)
368
+
369
+ out = self.activation(out)
370
+ out = self.conv2(out)
371
+ if self.norm2 is not None:
372
+ out = self.norm2(out)
373
+
374
+ return self.skip_add.add(out, x)
375
+
376
+
377
+ class FeatureFusionBlock(nn.Module):
378
+ """Feature fusion block."""
379
+
380
+ def __init__(
381
+ self,
382
+ features,
383
+ activation,
384
+ deconv=False,
385
+ bn=False,
386
+ expand=False,
387
+ align_corners=True,
388
+ size=None,
389
+ has_residual=True,
390
+ groups=1,
391
+ ):
392
+ """Init.
393
+
394
+ Args:
395
+ features (int): number of features
396
+ """
397
+ super(FeatureFusionBlock, self).__init__()
398
+
399
+ self.deconv = deconv
400
+ self.align_corners = align_corners
401
+ self.groups = groups
402
+ self.expand = expand
403
+ out_features = features
404
+ if self.expand:
405
+ out_features = features // 2
406
+
407
+ self.out_conv = nn.Conv2d(
408
+ features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
409
+ )
410
+
411
+ if has_residual:
412
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
413
+
414
+ self.has_residual = has_residual
415
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)
416
+
417
+ self.skip_add = nn.quantized.FloatFunctional()
418
+ self.size = size
419
+
420
+ def forward(self, *xs, size=None):
421
+ """Forward pass.
422
+
423
+ Returns:
424
+ tensor: output
425
+ """
426
+ output = xs[0]
427
+
428
+ if self.has_residual:
429
+ res = self.resConfUnit1(xs[1])
430
+ output = self.skip_add.add(output, res)
431
+
432
+ output = self.resConfUnit2(output)
433
+
434
+ if (size is None) and (self.size is None):
435
+ modifier = {"scale_factor": 2}
436
+ elif size is None:
437
+ modifier = {"size": self.size}
438
+ else:
439
+ modifier = {"size": size}
440
+
441
+ output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
442
+ output = self.out_conv(output)
443
+
444
+ return output
445
+
446
+
447
+ def custom_interpolate(
448
+ x: torch.Tensor,
449
+ size: Tuple[int, int] = None,
450
+ scale_factor: float = None,
451
+ mode: str = "bilinear",
452
+ align_corners: bool = True,
453
+ ) -> torch.Tensor:
454
+ """
455
+ Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
456
+ """
457
+ if size is None:
458
+ size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
459
+
460
+ INT_MAX = 1610612736
461
+
462
+ input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
463
+
464
+ if input_elements > INT_MAX:
465
+ chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
466
+ interpolated_chunks = [
467
+ nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks
468
+ ]
469
+ x = torch.cat(interpolated_chunks, dim=0)
470
+ return x.contiguous()
471
+ else:
472
+ return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
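
A small self-contained sketch of the token-to-feature-map reshape performed in _forward_impl before the 1x1 projections; the dimensions below are made up for illustration.

import torch

B, S, dim_in, patch_size = 1, 2, 64, 14
H, W = 140, 196                                   # image size, a multiple of patch_size
patch_h, patch_w = H // patch_size, W // patch_size

# Patch tokens for one intermediate layer: [B, S, patch_h * patch_w, dim_in]
tokens = torch.randn(B, S, patch_h * patch_w, dim_in)

x = tokens.reshape(B * S, -1, tokens.shape[-1])   # merge batch and frame dimensions
x = x.permute(0, 2, 1).reshape(x.shape[0], x.shape[-1], patch_h, patch_w)
print(x.shape)  # torch.Size([2, 64, 10, 14]) -> ready for the Conv2d projections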
vggt/heads/__pycache__/head_act.cpython-310.pyc → streamvggt/heads/head_act.py RENAMED
Binary files a/vggt/heads/__pycache__/head_act.cpython-310.pyc and b/streamvggt/heads/head_act.py differ
 
vggt/heads/__pycache__/track_head.cpython-310.pyc → streamvggt/heads/track_head.py RENAMED
Binary files a/vggt/heads/__pycache__/track_head.cpython-310.pyc and b/streamvggt/heads/track_head.py differ
 
streamvggt/heads/track_modules/__init__.py ADDED
File without changes
streamvggt/heads/track_modules/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (157 Bytes).
streamvggt/heads/track_modules/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (175 Bytes).
streamvggt/heads/track_modules/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (165 Bytes).
streamvggt/heads/track_modules/__pycache__/base_track_predictor.cpython-310.pyc ADDED
Binary file (4.27 kB).
streamvggt/heads/track_modules/__pycache__/base_track_predictor.cpython-311.pyc ADDED
Binary file (9.38 kB).
streamvggt/heads/track_modules/__pycache__/base_track_predictor.cpython-312.pyc ADDED
Binary file (8.77 kB).
streamvggt/heads/track_modules/__pycache__/blocks.cpython-310.pyc ADDED
Binary file (6.58 kB).
streamvggt/heads/track_modules/__pycache__/blocks.cpython-311.pyc ADDED
Binary file (12.9 kB).
streamvggt/heads/track_modules/__pycache__/blocks.cpython-312.pyc ADDED
Binary file (11.6 kB).
streamvggt/heads/track_modules/__pycache__/modules.cpython-310.pyc ADDED
Binary file (5.27 kB).
streamvggt/heads/track_modules/__pycache__/modules.cpython-311.pyc ADDED
Binary file (10 kB).
streamvggt/heads/track_modules/__pycache__/modules.cpython-312.pyc ADDED
Binary file (8.79 kB).
streamvggt/heads/track_modules/__pycache__/utils.cpython-310.pyc ADDED
Binary file (7.36 kB).
streamvggt/heads/track_modules/__pycache__/utils.cpython-311.pyc ADDED
Binary file (11.1 kB).
streamvggt/heads/track_modules/__pycache__/utils.cpython-312.pyc ADDED
Binary file (10.4 kB).
streamvggt/heads/track_modules/base_track_predictor.py ADDED
@@ -0,0 +1,195 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from einops import rearrange, repeat
4
+
5
+
6
+ from .blocks import EfficientUpdateFormer, CorrBlock
7
+ from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed
8
+ from .modules import Mlp
9
+
10
+
11
+ class BaseTrackerPredictor(nn.Module):
12
+ def __init__(
13
+ self,
14
+ stride=1,
15
+ corr_levels=5,
16
+ corr_radius=4,
17
+ latent_dim=128,
18
+ hidden_size=384,
19
+ use_spaceatt=True,
20
+ depth=6,
21
+ max_scale=518,
22
+ predict_conf=True,
23
+ ):
24
+ super(BaseTrackerPredictor, self).__init__()
25
+ self.stride = stride
26
+ self.latent_dim = latent_dim
27
+ self.corr_levels = corr_levels
28
+ self.corr_radius = corr_radius
29
+ self.hidden_size = hidden_size
30
+ self.max_scale = max_scale
31
+ self.predict_conf = predict_conf
32
+
33
+ self.flows_emb_dim = latent_dim // 2
34
+
35
+ self.corr_mlp = Mlp(
36
+ in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2,
37
+ hidden_features=self.hidden_size,
38
+ out_features=self.latent_dim,
39
+ )
40
+
41
+ self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4
42
+
43
+ self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim))
44
+
45
+ space_depth = depth if use_spaceatt else 0
46
+ time_depth = depth
47
+
48
+ self.updateformer = EfficientUpdateFormer(
49
+ space_depth=space_depth,
50
+ time_depth=time_depth,
51
+ input_dim=self.transformer_dim,
52
+ hidden_size=self.hidden_size,
53
+ output_dim=self.latent_dim + 2,
54
+ mlp_ratio=4.0,
55
+ add_space_attn=use_spaceatt,
56
+ )
57
+
58
+ self.fmap_norm = nn.LayerNorm(self.latent_dim)
59
+ self.ffeat_norm = nn.GroupNorm(1, self.latent_dim)
60
+
61
+ # A linear layer to update track feats at each iteration
62
+ self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU())
63
+
64
+ self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
65
+
66
+ if predict_conf:
67
+ self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
68
+
69
+ def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True):
70
+ """
71
+ query_points: B x N x 2, i.e. batch size, number of tracks, and (x, y) coordinates
72
+ fmaps: B x S x C x HH x WW, i.e. batch size, number of frames, feature dimension, and feature-map height/width;
73
+ note that HH and WW are the size of the feature maps, not the original images
74
+ """
75
+ B, N, D = query_points.shape
76
+ B, S, C, HH, WW = fmaps.shape
77
+
78
+ assert D == 2, "Input points must be 2D coordinates"
79
+
80
+ # apply a layernorm to fmaps here
81
+ fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2))
82
+ fmaps = fmaps.permute(0, 1, 4, 2, 3)
83
+
84
+ # Scale the input query_points because we may downsample the images
85
+ # by down_ratio or self.stride
86
+ # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map
87
+ # its query_points should be query_points/4
88
+ if down_ratio > 1:
89
+ query_points = query_points / float(down_ratio)
90
+
91
+ query_points = query_points / float(self.stride)
92
+
93
+ # Init with coords as the query points
94
+ # It means the search will start from the position of query points at the reference frames
95
+ coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)
96
+
97
+ # Sample/extract the features of the query points in the query frame
98
+ query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])
99
+
100
+ # init track feats by query feats
101
+ track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C
102
+ # back up the init coords
103
+ coords_backup = coords.clone()
104
+
105
+ fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius)
106
+
107
+ coord_preds = []
108
+
109
+ # Iterative Refinement
110
+ for _ in range(iters):
111
+ # Detach the gradients from the last iteration
112
+ # (in my experience, not very important for performance)
113
+ coords = coords.detach()
114
+
115
+ fcorrs = fcorr_fn.corr_sample(track_feats, coords)
116
+
117
+ corr_dim = fcorrs.shape[3]
118
+ fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim)
119
+ fcorrs_ = self.corr_mlp(fcorrs_)
120
+
121
+ # Movement of current coords relative to query points
122
+ flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
123
+
124
+ flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)
125
+
126
+ # (In my trials, it is also okay to just add the flows_emb instead of concat)
127
+ flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1)
128
+
129
+ track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
130
+
131
+ # Concatenate them as the input for the transformers
132
+ transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)
133
+
134
+ # 2D positional embed
135
+ pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device)
136
+ sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0])
137
+
138
+ sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1)
139
+
140
+ x = transformer_input + sampled_pos_emb
141
+
142
+ # Add the query ref token to the track feats
143
+ query_ref_token = torch.cat(
144
+ [self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1
145
+ )
146
+ x = x + query_ref_token.to(x.device).to(x.dtype)
147
+
148
+ # B, N, S, C
149
+ x = rearrange(x, "(b n) s d -> b n s d", b=B)
150
+
151
+ # Compute the delta coordinates and delta track features
152
+ delta, _ = self.updateformer(x)
153
+
154
+ # BN, S, C
155
+ delta = rearrange(delta, " b n s d -> (b n) s d", b=B)
156
+ delta_coords_ = delta[:, :, :2]
157
+ delta_feats_ = delta[:, :, 2:]
158
+
159
+ track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
160
+ delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
161
+
162
+ # Update the track features
163
+ track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_
164
+
165
+ track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC
166
+
167
+ # B x S x N x 2
168
+ coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
169
+
170
+ # Force coord0 as query
171
+ # because we assume the query points should not be changed
172
+ coords[:, 0] = coords_backup[:, 0]
173
+
174
+ # The predicted tracks are in the original image scale
175
+ if down_ratio > 1:
176
+ coord_preds.append(coords * self.stride * down_ratio)
177
+ else:
178
+ coord_preds.append(coords * self.stride)
179
+
180
+ # B, S, N
181
+ vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
182
+ if apply_sigmoid:
183
+ vis_e = torch.sigmoid(vis_e)
184
+
185
+ if self.predict_conf:
186
+ conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
187
+ if apply_sigmoid:
188
+ conf_e = torch.sigmoid(conf_e)
189
+ else:
190
+ conf_e = None
191
+
192
+ if return_feat:
193
+ return coord_preds, vis_e, track_feats, query_track_feat, conf_e
194
+ else:
195
+ return coord_preds, vis_e, conf_e
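
A quick worked example of the coordinate-scaling convention noted in the comments above: query points arrive in image pixels, are divided by down_ratio and stride before correlation sampling, and the predicted tracks are multiplied back to image scale. Numbers are illustrative.

import torch

stride, down_ratio = 1, 4          # e.g. a 1024x1024 image handled on a 256x256 feature map
query_points = torch.tensor([[[512.0, 256.0]]])   # B=1, N=1, (x, y) in image pixels

# Scale into feature-map coordinates, as in BaseTrackerPredictor.forward
coords = query_points / float(down_ratio) / float(stride)
print(coords)                      # tensor([[[128.,  64.]]])

# Predicted tracks are mapped back to the original image scale
coords_image = coords * stride * down_ratio
print(coords_image)                # tensor([[[512., 256.]]])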
streamvggt/heads/track_modules/blocks.py ADDED
@@ -0,0 +1,237 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from .utils import bilinear_sampler
7
+ from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock
8
+
9
+
10
+ class EfficientUpdateFormer(nn.Module):
11
+ """
12
+ Transformer model that updates track estimates.
13
+ """
14
+
15
+ def __init__(
16
+ self,
17
+ space_depth=6,
18
+ time_depth=6,
19
+ input_dim=320,
20
+ hidden_size=384,
21
+ num_heads=8,
22
+ output_dim=130,
23
+ mlp_ratio=4.0,
24
+ add_space_attn=True,
25
+ num_virtual_tracks=64,
26
+ ):
27
+ super().__init__()
28
+
29
+ self.out_channels = 2
30
+ self.num_heads = num_heads
31
+ self.hidden_size = hidden_size
32
+ self.add_space_attn = add_space_attn
33
+
34
+ # Add input LayerNorm before linear projection
35
+ self.input_norm = nn.LayerNorm(input_dim)
36
+ self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
37
+
38
+ # Add output LayerNorm before final projection
39
+ self.output_norm = nn.LayerNorm(hidden_size)
40
+ self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
41
+ self.num_virtual_tracks = num_virtual_tracks
42
+
43
+ if self.add_space_attn:
44
+ self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))
45
+ else:
46
+ self.virual_tracks = None
47
+
48
+ self.time_blocks = nn.ModuleList(
49
+ [
50
+ AttnBlock(
51
+ hidden_size,
52
+ num_heads,
53
+ mlp_ratio=mlp_ratio,
54
+ attn_class=nn.MultiheadAttention,
55
+ )
56
+ for _ in range(time_depth)
57
+ ]
58
+ )
59
+
60
+ if add_space_attn:
61
+ self.space_virtual_blocks = nn.ModuleList(
62
+ [
63
+ AttnBlock(
64
+ hidden_size,
65
+ num_heads,
66
+ mlp_ratio=mlp_ratio,
67
+ attn_class=nn.MultiheadAttention,
68
+ )
69
+ for _ in range(space_depth)
70
+ ]
71
+ )
72
+ self.space_point2virtual_blocks = nn.ModuleList(
73
+ [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
74
+ )
75
+ self.space_virtual2point_blocks = nn.ModuleList(
76
+ [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
77
+ )
78
+ assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
79
+ self.initialize_weights()
80
+
81
+ def initialize_weights(self):
82
+ def _basic_init(module):
83
+ if isinstance(module, nn.Linear):
84
+ torch.nn.init.xavier_uniform_(module.weight)
85
+ if module.bias is not None:
86
+ nn.init.constant_(module.bias, 0)
87
+ torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
88
+
89
+ self.apply(_basic_init)
90
+
91
+ def forward(self, input_tensor, mask=None):
92
+ # Apply input LayerNorm
93
+ input_tensor = self.input_norm(input_tensor)
94
+ tokens = self.input_transform(input_tensor)
95
+
96
+ init_tokens = tokens
97
+
98
+ B, _, T, _ = tokens.shape
99
+
100
+ if self.add_space_attn:
101
+ virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
102
+ tokens = torch.cat([tokens, virtual_tokens], dim=1)
103
+
104
+ _, N, _, _ = tokens.shape
105
+
106
+ j = 0
107
+ for i in range(len(self.time_blocks)):
108
+ time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
109
+
110
+ time_tokens = self.time_blocks[i](time_tokens)
111
+
112
+ tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
113
+ if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0):
114
+ space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C
115
+ point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
116
+ virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
117
+
118
+ virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask)
119
+ virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
120
+ point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask)
121
+
122
+ space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
123
+ tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C
124
+ j += 1
125
+
126
+ if self.add_space_attn:
127
+ tokens = tokens[:, : N - self.num_virtual_tracks]
128
+
129
+ tokens = tokens + init_tokens
130
+
131
+ # Apply output LayerNorm before final projection
132
+ tokens = self.output_norm(tokens)
133
+ flow = self.flow_head(tokens)
134
+
135
+ return flow, None
136
+
137
+
138
+ class CorrBlock:
139
+ def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"):
140
+ """
141
+ Build a pyramid of feature maps from the input.
142
+
143
+ fmaps: Tensor (B, S, C, H, W)
144
+ num_levels: number of pyramid levels (each downsampled by factor 2)
145
+ radius: search radius for sampling correlation
146
+ multiple_track_feats: if True, split the target features per pyramid level
147
+ padding_mode: passed to grid_sample / bilinear_sampler
148
+ """
149
+ B, S, C, H, W = fmaps.shape
150
+ self.S, self.C, self.H, self.W = S, C, H, W
151
+ self.num_levels = num_levels
152
+ self.radius = radius
153
+ self.padding_mode = padding_mode
154
+ self.multiple_track_feats = multiple_track_feats
155
+
156
+ # Build pyramid: each level is half the spatial resolution of the previous
157
+ self.fmaps_pyramid = [fmaps] # level 0 is full resolution
158
+ current_fmaps = fmaps
159
+ for i in range(num_levels - 1):
160
+ B, S, C, H, W = current_fmaps.shape
161
+ # Merge batch & sequence dimensions
162
+ current_fmaps = current_fmaps.reshape(B * S, C, H, W)
163
+ # Avg pool down by factor 2
164
+ current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2)
165
+ _, _, H_new, W_new = current_fmaps.shape
166
+ current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new)
167
+ self.fmaps_pyramid.append(current_fmaps)
168
+
169
+ # Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling.
170
+ # This grid is added to the (scaled) coordinate centroids.
171
+ r = self.radius
172
+ dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
173
+ dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
174
+ # delta: for every (dy,dx) displacement (i.e. Δx, Δy)
175
+ self.delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1) # shape: (2r+1, 2r+1, 2)
176
+
177
+ def corr_sample(self, targets, coords):
178
+ """
179
+ Instead of storing the entire correlation pyramid, we compute each level's correlation
180
+ volume, sample it immediately, then discard it. This saves GPU memory.
181
+
182
+ Args:
183
+ targets: Tensor (B, S, N, C) — features for the current targets.
184
+ coords: Tensor (B, S, N, 2) — coordinates at full resolution.
185
+
186
+ Returns:
187
+ Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations)
188
+ """
189
+ B, S, N, C = targets.shape
190
+
191
+ # If you have multiple track features, split them per level.
192
+ if self.multiple_track_feats:
193
+ targets_split = torch.split(targets, C // self.num_levels, dim=-1)
194
+
195
+ out_pyramid = []
196
+ for i, fmaps in enumerate(self.fmaps_pyramid):
197
+ # Get current spatial resolution H, W for this pyramid level.
198
+ B, S, C, H, W = fmaps.shape
199
+ # Reshape feature maps for correlation computation:
200
+ # fmap2s: (B, S, C, H*W)
201
+ fmap2s = fmaps.view(B, S, C, H * W)
202
+ # Choose appropriate target features.
203
+ fmap1 = targets_split[i] if self.multiple_track_feats else targets # shape: (B, S, N, C)
204
+
205
+ # Compute correlation directly
206
+ corrs = compute_corr_level(fmap1, fmap2s, C)
207
+ corrs = corrs.view(B, S, N, H, W)
208
+
209
+ # Prepare sampling grid:
210
+ # Scale down the coordinates for the current level.
211
+ centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i)
212
+ # Make sure our precomputed delta grid is on the same device/dtype.
213
+ delta_lvl = self.delta.to(coords.device).to(coords.dtype)
214
+ # Now the grid for grid_sample is:
215
+ # coords_lvl = centroid_lvl + delta_lvl (broadcasted over grid)
216
+ coords_lvl = centroid_lvl + delta_lvl.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2)
217
+
218
+ # Sample from the correlation volume using bilinear interpolation.
219
+ # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target.
220
+ corrs_sampled = bilinear_sampler(
221
+ corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode
222
+ )
223
+ # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims.
224
+ corrs_sampled = corrs_sampled.view(B, S, N, -1) # Now shape: (B, S, N, (2r+1)^2)
225
+ out_pyramid.append(corrs_sampled)
226
+
227
+ # Concatenate all levels along the last dimension.
228
+ out = torch.cat(out_pyramid, dim=-1).contiguous()
229
+ return out
230
+
231
+
232
+ def compute_corr_level(fmap1, fmap2s, C):
233
+ # fmap1: (B, S, N, C)
234
+ # fmap2s: (B, S, C, H*W)
235
+ corrs = torch.matmul(fmap1, fmap2s) # (B, S, N, H*W)
236
+ corrs = corrs.view(fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1) # (B, S, N, H*W)
237
+ return corrs / math.sqrt(C)
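
A self-contained shape check for compute_corr_level: each of the N track features is correlated against all H*W locations of every frame's feature map and scaled by sqrt(C). Sizes below are arbitrary.

import math
import torch

B, S, N, C, H, W = 1, 2, 3, 16, 8, 8
fmap1 = torch.randn(B, S, N, C)            # target features per track
fmap2s = torch.randn(B, S, C, H * W)       # flattened feature maps

corrs = torch.matmul(fmap1, fmap2s) / math.sqrt(C)   # (B, S, N, H*W)
print(corrs.shape)                          # torch.Size([1, 2, 3, 64])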
streamvggt/heads/track_modules/modules.py ADDED
@@ -0,0 +1,211 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from functools import partial
5
+ from typing import Callable
6
+ import collections
7
+ from torch import Tensor
8
+ from itertools import repeat
9
+
10
+
11
+ # From PyTorch internals
12
+ def _ntuple(n):
13
+ def parse(x):
14
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
15
+ return tuple(x)
16
+ return tuple(repeat(x, n))
17
+
18
+ return parse
19
+
20
+
21
+ def exists(val):
22
+ return val is not None
23
+
24
+
25
+ def default(val, d):
26
+ return val if exists(val) else d
27
+
28
+
29
+ to_2tuple = _ntuple(2)
30
+
31
+
32
+ class ResidualBlock(nn.Module):
33
+ """
34
+ ResidualBlock: construct a block of two conv layers with residual connections
35
+ """
36
+
37
+ def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
38
+ super(ResidualBlock, self).__init__()
39
+
40
+ self.conv1 = nn.Conv2d(
41
+ in_planes,
42
+ planes,
43
+ kernel_size=kernel_size,
44
+ padding=1,
45
+ stride=stride,
46
+ padding_mode="zeros",
47
+ )
48
+ self.conv2 = nn.Conv2d(
49
+ planes,
50
+ planes,
51
+ kernel_size=kernel_size,
52
+ padding=1,
53
+ padding_mode="zeros",
54
+ )
55
+ self.relu = nn.ReLU(inplace=True)
56
+
57
+ num_groups = planes // 8
58
+
59
+ if norm_fn == "group":
60
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
61
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
62
+ if not stride == 1:
63
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
64
+
65
+ elif norm_fn == "batch":
66
+ self.norm1 = nn.BatchNorm2d(planes)
67
+ self.norm2 = nn.BatchNorm2d(planes)
68
+ if not stride == 1:
69
+ self.norm3 = nn.BatchNorm2d(planes)
70
+
71
+ elif norm_fn == "instance":
72
+ self.norm1 = nn.InstanceNorm2d(planes)
73
+ self.norm2 = nn.InstanceNorm2d(planes)
74
+ if not stride == 1:
75
+ self.norm3 = nn.InstanceNorm2d(planes)
76
+
77
+ elif norm_fn == "none":
78
+ self.norm1 = nn.Sequential()
79
+ self.norm2 = nn.Sequential()
80
+ if not stride == 1:
81
+ self.norm3 = nn.Sequential()
82
+ else:
83
+ raise NotImplementedError
84
+
85
+ if stride == 1:
86
+ self.downsample = None
87
+ else:
88
+ self.downsample = nn.Sequential(
89
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
90
+ self.norm3,
91
+ )
92
+
93
+ def forward(self, x):
94
+ y = x
95
+ y = self.relu(self.norm1(self.conv1(y)))
96
+ y = self.relu(self.norm2(self.conv2(y)))
97
+
98
+ if self.downsample is not None:
99
+ x = self.downsample(x)
100
+
101
+ return self.relu(x + y)
102
+
103
+
104
+ class Mlp(nn.Module):
105
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
106
+
107
+ def __init__(
108
+ self,
109
+ in_features,
110
+ hidden_features=None,
111
+ out_features=None,
112
+ act_layer=nn.GELU,
113
+ norm_layer=None,
114
+ bias=True,
115
+ drop=0.0,
116
+ use_conv=False,
117
+ ):
118
+ super().__init__()
119
+ out_features = out_features or in_features
120
+ hidden_features = hidden_features or in_features
121
+ bias = to_2tuple(bias)
122
+ drop_probs = to_2tuple(drop)
123
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
124
+
125
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
126
+ self.act = act_layer()
127
+ self.drop1 = nn.Dropout(drop_probs[0])
128
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
129
+ self.drop2 = nn.Dropout(drop_probs[1])
130
+
131
+ def forward(self, x):
132
+ x = self.fc1(x)
133
+ x = self.act(x)
134
+ x = self.drop1(x)
135
+ x = self.fc2(x)
136
+ x = self.drop2(x)
137
+ return x
138
+
139
+
140
+ class AttnBlock(nn.Module):
141
+ def __init__(
142
+ self,
143
+ hidden_size,
144
+ num_heads,
145
+ attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
146
+ mlp_ratio=4.0,
147
+ **block_kwargs
148
+ ):
149
+ """
150
+ Self attention block
151
+ """
152
+ super().__init__()
153
+
154
+ self.norm1 = nn.LayerNorm(hidden_size)
155
+ self.norm2 = nn.LayerNorm(hidden_size)
156
+
157
+ self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs)
158
+
159
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
160
+
161
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
162
+
163
+ def forward(self, x, mask=None):
164
+ # Prepare the mask for PyTorch's attention (it expects a different format)
165
+ # attn_mask = mask if mask is not None else None
166
+ # Normalize before attention
167
+ x = self.norm1(x)
168
+
169
+ # PyTorch's MultiheadAttention returns attn_output, attn_output_weights
170
+ # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask)
171
+
172
+ attn_output, _ = self.attn(x, x, x)
173
+
174
+ # Add & Norm
175
+ x = x + attn_output
176
+ x = x + self.mlp(self.norm2(x))
177
+ return x
178
+
179
+
180
+ class CrossAttnBlock(nn.Module):
181
+ def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
182
+ """
183
+ Cross attention block
184
+ """
185
+ super().__init__()
186
+
187
+ self.norm1 = nn.LayerNorm(hidden_size)
188
+ self.norm_context = nn.LayerNorm(hidden_size)
189
+ self.norm2 = nn.LayerNorm(hidden_size)
190
+
191
+ self.cross_attn = nn.MultiheadAttention(
192
+ embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
193
+ )
194
+
195
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
196
+
197
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
198
+
199
+ def forward(self, x, context, mask=None):
200
+ # Normalize inputs
201
+ x = self.norm1(x)
202
+ context = self.norm_context(context)
203
+
204
+ # Apply cross attention
205
+ # Note: nn.MultiheadAttention returns attn_output, attn_output_weights
206
+ attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask)
207
+
208
+ # Add & Norm
209
+ x = x + attn_output
210
+ x = x + self.mlp(self.norm2(x))
211
+ return x
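
A minimal usage sketch for the blocks above, assuming the streamvggt package is importable; it only checks that AttnBlock and CrossAttnBlock preserve the token shape.

import torch
from streamvggt.heads.track_modules.modules import AttnBlock, CrossAttnBlock

hidden = 384
x = torch.randn(2, 10, hidden)        # (batch, tokens, channels); attention is batch_first
ctx = torch.randn(2, 5, hidden)

self_block = AttnBlock(hidden_size=hidden, num_heads=8, mlp_ratio=4.0)
cross_block = CrossAttnBlock(hidden_size=hidden, context_dim=hidden, num_heads=8)

print(self_block(x).shape)            # torch.Size([2, 10, 384])
print(cross_block(x, ctx).shape)      # torch.Size([2, 10, 384])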
streamvggt/heads/track_modules/utils.py ADDED
@@ -0,0 +1,216 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from typing import Optional, Tuple, Union
6
+
7
+
8
+ def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor:
9
+ """
10
+ This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
11
+ It is a wrapper of get_2d_sincos_pos_embed_from_grid.
12
+ Args:
13
+ - embed_dim: The embedding dimension.
14
+ - grid_size: The grid size.
15
+ Returns:
16
+ - pos_embed: The generated 2D positional embedding.
17
+ """
18
+ if isinstance(grid_size, tuple):
19
+ grid_size_h, grid_size_w = grid_size
20
+ else:
21
+ grid_size_h = grid_size_w = grid_size
22
+ grid_h = torch.arange(grid_size_h, dtype=torch.float)
23
+ grid_w = torch.arange(grid_size_w, dtype=torch.float)
24
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
25
+ grid = torch.stack(grid, dim=0)
26
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
27
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
28
+ if return_grid:
29
+ return (
30
+ pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2),
31
+ grid,
32
+ )
33
+ return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
34
+
35
+
36
+ def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor:
37
+ """
38
+ This function generates a 2D positional embedding from a given grid using sine and cosine functions.
39
+
40
+ Args:
41
+ - embed_dim: The embedding dimension.
42
+ - grid: The grid to generate the embedding from.
43
+
44
+ Returns:
45
+ - emb: The generated 2D positional embedding.
46
+ """
47
+ assert embed_dim % 2 == 0
48
+
49
+ # use half of dimensions to encode grid_h
50
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
51
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
52
+
53
+ emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
54
+ return emb
55
+
56
+
57
+ def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:
58
+ """
59
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
60
+
61
+ Args:
62
+ - embed_dim: The embedding dimension.
63
+ - pos: The position to generate the embedding from.
64
+
65
+ Returns:
66
+ - emb: The generated 1D positional embedding.
67
+ """
68
+ assert embed_dim % 2 == 0
69
+ omega = torch.arange(embed_dim // 2, dtype=torch.double)
70
+ omega /= embed_dim / 2.0
71
+ omega = 1.0 / 10000**omega # (D/2,)
72
+
73
+ pos = pos.reshape(-1) # (M,)
74
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
75
+
76
+ emb_sin = torch.sin(out) # (M, D/2)
77
+ emb_cos = torch.cos(out) # (M, D/2)
78
+
79
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
80
+ return emb[None].float()
81
+
82
+
83
+ def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
84
+ """
85
+ This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
86
+
87
+ Args:
88
+ - xy: The coordinates to generate the embedding from.
89
+ - C: The size of the embedding.
90
+ - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
91
+
92
+ Returns:
93
+ - pe: The generated 2D positional embedding.
94
+ """
95
+ B, N, D = xy.shape
96
+ assert D == 2
97
+
98
+ x = xy[:, :, 0:1]
99
+ y = xy[:, :, 1:2]
100
+ div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2))
101
+
102
+ pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
103
+ pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
104
+
105
+ pe_x[:, :, 0::2] = torch.sin(x * div_term)
106
+ pe_x[:, :, 1::2] = torch.cos(x * div_term)
107
+
108
+ pe_y[:, :, 0::2] = torch.sin(y * div_term)
109
+ pe_y[:, :, 1::2] = torch.cos(y * div_term)
110
+
111
+ pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, 2*C)
112
+ if cat_coords:
113
+ pe = torch.cat([xy, pe], dim=2) # (B, N, 2*C + 2)
114
+ return pe
115
+
116
+
117
+ def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
118
+ r"""Sample a tensor using bilinear interpolation
119
+
120
+ `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
121
+ coordinates :attr:`coords` using bilinear interpolation. It is the same
122
+ as `torch.nn.functional.grid_sample()` but with a different coordinate
123
+ convention.
124
+
125
+ The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
126
+ :math:`B` is the batch size, :math:`C` is the number of channels,
127
+ :math:`H` is the height of the image, and :math:`W` is the width of the
128
+ image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
129
+ interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
130
+
131
+ Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
132
+ in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
133
+ that in this case the order of the components is slightly different
134
+ from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
135
+
136
+ If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
137
+ in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
138
+ left-most image pixel :math:`W-1` to the center of the right-most
139
+ pixel.
140
+
141
+ If `align_corners` is `False`, the coordinate :math:`x` is assumed to
142
+ be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
143
+ the left-most pixel :math:`W` to the right edge of the right-most
144
+ pixel.
145
+
146
+ Similar conventions apply to the :math:`y` for the range
147
+ :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
148
+ :math:`[0,T-1]` and :math:`[0,T]`.
149
+
150
+ Args:
151
+ input (Tensor): batch of input images.
152
+ coords (Tensor): batch of coordinates.
153
+ align_corners (bool, optional): Coordinate convention. Defaults to `True`.
154
+ padding_mode (str, optional): Padding mode. Defaults to `"border"`.
155
+
156
+ Returns:
157
+ Tensor: sampled points.
158
+ """
159
+ coords = coords.detach().clone()
160
+ ############################################################
161
+ # IMPORTANT:
162
+ coords = coords.to(input.device).to(input.dtype)
163
+ ############################################################
164
+
165
+ sizes = input.shape[2:]
166
+
167
+ assert len(sizes) in [2, 3]
168
+
169
+ if len(sizes) == 3:
170
+ # t x y -> x y t to match dimensions T H W in grid_sample
171
+ coords = coords[..., [1, 2, 0]]
172
+
173
+ if align_corners:
174
+ scale = torch.tensor(
175
+ [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device, dtype=coords.dtype
176
+ )
177
+ else:
178
+ scale = torch.tensor([2 / size for size in reversed(sizes)], device=coords.device, dtype=coords.dtype)
179
+
180
+ coords.mul_(scale) # coords = coords * scale
181
+ coords.sub_(1) # coords = coords - 1
182
+
183
+ return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
184
+
185
+
186
+ def sample_features4d(input, coords):
187
+ r"""Sample spatial features
188
+
189
+ `sample_features4d(input, coords)` samples the spatial features
190
+ :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
191
+
192
+ The field is sampled at coordinates :attr:`coords` using bilinear
193
+ interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
194
+ 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
195
+ same convention as :func:`bilinear_sampler` with `align_corners=True`.
196
+
197
+ The output tensor has one feature per point, and has shape :math:`(B,
198
+ R, C)`.
199
+
200
+ Args:
201
+ input (Tensor): spatial features.
202
+ coords (Tensor): points.
203
+
204
+ Returns:
205
+ Tensor: sampled features.
206
+ """
207
+
208
+ B, _, _, _ = input.shape
209
+
210
+ # B R 2 -> B R 1 2
211
+ coords = coords.unsqueeze(2)
212
+
213
+ # B C R 1
214
+ feats = bilinear_sampler(input, coords)
215
+
216
+ return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C
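
Two quick shape checks for the helpers above, written as a hedged sketch (the import path simply mirrors the file path added in this commit): get_2d_sincos_pos_embed returns a channel-first (1, D, H, W) grid embedding, and sample_features4d pulls one C-dimensional feature per query point via the normalized-coordinate bilinear_sampler.

import torch
from streamvggt.heads.track_modules.utils import get_2d_sincos_pos_embed, sample_features4d

pos = get_2d_sincos_pos_embed(embed_dim=64, grid_size=(8, 12))
print(pos.shape)                                  # torch.Size([1, 64, 8, 12])

feat = torch.randn(2, 32, 16, 16)                 # (B, C, H, W) feature map
pts = torch.tensor([[[3.0, 5.0], [10.0, 2.0]],
                    [[7.0, 7.0], [1.0, 12.0]]])   # (B, R, 2) pixel coordinates (x, y)
sampled = sample_features4d(feat, pts)
print(sampled.shape)                              # torch.Size([2, 2, 32]): one feature per point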
vggt/heads/__pycache__/utils.cpython-310.pyc → streamvggt/heads/utils.py RENAMED
Binary files a/vggt/heads/__pycache__/utils.cpython-310.pyc and b/streamvggt/heads/utils.py differ
 
streamvggt/layers/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .mlp import Mlp
2
+ from .patch_embed import PatchEmbed
3
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
4
+ from .block import NestedTensorBlock
5
+ from .attention import MemEffAttention
streamvggt/layers/attention.py ADDED
@@ -0,0 +1,129 @@
1
+ import logging
2
+ import os
3
+ import warnings
4
+
5
+ import torch
6
+ from torch import Tensor
7
+ from torch import nn
8
+ import torch.nn.functional as F
9
+ from typing import Union, Tuple, Dict, Optional
10
+
11
+ from einops import rearrange
12
+
13
+ XFORMERS_AVAILABLE = False
14
+
15
+
16
+ class Attention(nn.Module):
17
+ def __init__(
18
+ self,
19
+ dim: int,
20
+ num_heads: int = 8,
21
+ qkv_bias: bool = True,
22
+ proj_bias: bool = True,
23
+ attn_drop: float = 0.0,
24
+ proj_drop: float = 0.0,
25
+ norm_layer: nn.Module = nn.LayerNorm,
26
+ qk_norm: bool = False,
27
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
28
+ rope=None,
29
+ ) -> None:
30
+ super().__init__()
31
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
32
+ self.num_heads = num_heads
33
+ self.head_dim = dim // num_heads
34
+ self.scale = self.head_dim**-0.5
35
+ self.fused_attn = fused_attn
36
+
37
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
38
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
39
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
40
+ self.attn_drop = nn.Dropout(attn_drop)
41
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
42
+ self.proj_drop = nn.Dropout(proj_drop)
43
+ self.rope = rope
44
+
45
+ def forward(self,
46
+ x: torch.Tensor,
47
+ pos=None,
48
+ attn_mask=None,
49
+ past_key_values=None,
50
+ use_cache=False
51
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple]]:
52
+ B, N, C = x.shape
53
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
54
+ q, k, v = qkv.unbind(0)
55
+
56
+ pos_k = pos
57
+ if use_cache:
58
+ k = k.unsqueeze(2)
59
+ v = v.unsqueeze(2)
60
+ if past_key_values is not None:
61
+ past_k, past_v = past_key_values
62
+ k = torch.cat([past_k, k], dim=2)
63
+ v = torch.cat([past_v, v], dim=2)
64
+
65
+ new_kv = (k, v)
66
+ a, b, c, d, e = k.shape
67
+ k = k.reshape(a, b, c*d, e)
68
+ v = v.reshape(a, b, c*d, e)
69
+ if pos_k is not None:
70
+ #print(pos_k.shape)
71
+ pos_k = pos_k.repeat(1, c, 1)
72
+ #print(pos_k.shape)
73
+
74
+ q, k = self.q_norm(q), self.k_norm(k)
75
+
76
+ if self.rope is not None:
77
+ q = self.rope(q, pos)
78
+ k = self.rope(k, pos_k)
79
+
80
+ if self.fused_attn:
81
+ x = F.scaled_dot_product_attention(
82
+ q,
83
+ k,
84
+ v,
85
+ attn_mask=attn_mask,
86
+ dropout_p=self.attn_drop.p if self.training else 0.0,
87
+ )
88
+
89
+ else:
90
+ q = q * self.scale
91
+ attn = q @ k.transpose(-2, -1)
92
+
93
+ # Mask
94
+ if attn_mask is not None:
95
+ assert attn_mask.shape[-2:] == (N, N), f"Expected mask shape [..., {N}, {N}], got {attn_mask.shape}"
96
+ attn = attn + attn_mask
97
+
98
+ attn = attn.softmax(dim=-1)
99
+ attn = self.attn_drop(attn)
100
+ x = attn @ v
101
+
102
+ x = x.transpose(1, 2).reshape(B, N, C)
103
+ x = self.proj(x)
104
+ x = self.proj_drop(x)
105
+ if use_cache:
106
+ return x, new_kv
107
+ return x
108
+
109
+
110
+ class MemEffAttention(Attention):
111
+ def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor:
112
+ assert pos is None
113
+ if not XFORMERS_AVAILABLE:
114
+ if attn_bias is not None:
115
+ raise AssertionError("xFormers is required for using nested tensors")
116
+ return super().forward(x)
117
+
118
+ B, N, C = x.shape
119
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
120
+
121
+ q, k, v = unbind(qkv, 2)
122
+
123
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
124
+ x = x.reshape([B, N, C])
125
+
126
+ x = self.proj(x)
127
+ x = self.proj_drop(x)
128
+
129
+ return x
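
The use_cache branch of Attention above stores keys/values per streamed step along an extra dimension (B, heads, steps, N, head_dim) and flattens them before scaled_dot_product_attention. A small sketch of streaming two frames through it; sizes are illustrative and the import path mirrors the file added in this commit:

import torch
from streamvggt.layers.attention import Attention

attn = Attention(dim=64, num_heads=4).eval()
x1 = torch.randn(1, 10, 64)          # tokens of the first frame  (B, N, C)
x2 = torch.randn(1, 10, 64)          # tokens of the next frame

with torch.no_grad():
    out1, kv = attn(x1, use_cache=True)                      # kv[0]: (1, 4, 1, 10, 16)
    out2, kv = attn(x2, past_key_values=kv, use_cache=True)
print(kv[0].shape)                   # torch.Size([1, 4, 2, 10, 16]) -- keys of both cached steps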
streamvggt/layers/block.py ADDED
@@ -0,0 +1,263 @@
1
+ import logging
2
+ import os
3
+ from typing import Callable, List, Any, Tuple, Dict, Union
4
+ import warnings
5
+
6
+ import torch
7
+ from torch import nn, Tensor
8
+
9
+ from .attention import Attention
10
+ from .drop_path import DropPath
11
+ from .layer_scale import LayerScale
12
+ from .mlp import Mlp
13
+
14
+
15
+ XFORMERS_AVAILABLE = False
16
+
17
+
18
+ class Block(nn.Module):
19
+ def __init__(
20
+ self,
21
+ dim: int,
22
+ num_heads: int,
23
+ mlp_ratio: float = 4.0,
24
+ qkv_bias: bool = True,
25
+ proj_bias: bool = True,
26
+ ffn_bias: bool = True,
27
+ drop: float = 0.0,
28
+ attn_drop: float = 0.0,
29
+ init_values=None,
30
+ drop_path: float = 0.0,
31
+ act_layer: Callable[..., nn.Module] = nn.GELU,
32
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
33
+ attn_class: Callable[..., nn.Module] = Attention,
34
+ ffn_layer: Callable[..., nn.Module] = Mlp,
35
+ qk_norm: bool = False,
36
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
37
+ rope=None,
38
+ ) -> None:
39
+ super().__init__()
40
+
41
+ self.norm1 = norm_layer(dim)
42
+
43
+ self.attn = attn_class(
44
+ dim,
45
+ num_heads=num_heads,
46
+ qkv_bias=qkv_bias,
47
+ proj_bias=proj_bias,
48
+ attn_drop=attn_drop,
49
+ proj_drop=drop,
50
+ qk_norm=qk_norm,
51
+ fused_attn=fused_attn,
52
+ rope=rope,
53
+ )
54
+
55
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
56
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
57
+
58
+ self.norm2 = norm_layer(dim)
59
+ mlp_hidden_dim = int(dim * mlp_ratio)
60
+ self.mlp = ffn_layer(
61
+ in_features=dim,
62
+ hidden_features=mlp_hidden_dim,
63
+ act_layer=act_layer,
64
+ drop=drop,
65
+ bias=ffn_bias,
66
+ )
67
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
68
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
69
+
70
+ self.sample_drop_ratio = drop_path
71
+
72
+ def forward(self, x: Tensor, pos=None, attn_mask=None, past_key_values=None, use_cache=False) -> Union[Tensor, Tuple[Tensor, Dict]]:
73
+
74
+ def attn_residual_func(x: Tensor, pos=None, attn_mask=None, past_key_values=None, use_cache=False) -> Union[Tensor, Tuple[Tensor, Dict]]:
75
+ if use_cache:
76
+ output, new_kv = self.attn(self.norm1(x), pos=pos, past_key_values=past_key_values, use_cache=True)
77
+ return self.ls1(output), new_kv
78
+ else:
79
+ if attn_mask is not None:
80
+ return self.ls1(self.attn(self.norm1(x), pos=pos, attn_mask=attn_mask))
81
+ else:
82
+ return self.ls1(self.attn(self.norm1(x), pos=pos))
83
+ def ffn_residual_func(x: Tensor) -> Tensor:
84
+ return self.ls2(self.mlp(self.norm2(x)))
85
+
86
+ if use_cache:
87
+ attn_output, new_kv = attn_residual_func(x, pos=pos, past_key_values=past_key_values, use_cache=True)
88
+ x = x + attn_output
89
+ x = x + ffn_residual_func(x)
90
+ return x, new_kv
91
+
92
+ if self.training and self.sample_drop_ratio > 0.1:
93
+ # the overhead is compensated only for a drop path rate larger than 0.1
94
+ x = drop_add_residual_stochastic_depth(
95
+ x,
96
+ pos=pos,
97
+ residual_func=attn_residual_func,
98
+ sample_drop_ratio=self.sample_drop_ratio,
99
+ )
100
+ x = drop_add_residual_stochastic_depth(
101
+ x,
102
+ residual_func=ffn_residual_func,
103
+ sample_drop_ratio=self.sample_drop_ratio,
104
+ )
105
+ elif self.training and self.sample_drop_ratio > 0.0:
106
+ x = x + self.drop_path1(attn_residual_func(x, pos=pos, attn_mask=attn_mask))
107
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
108
+ else:
109
+ x = x + attn_residual_func(x, pos=pos, attn_mask=attn_mask)
110
+ x = x + ffn_residual_func(x)
111
+ return x
112
+
113
+
114
+ def drop_add_residual_stochastic_depth(
115
+ x: Tensor,
116
+ residual_func: Callable[[Tensor], Tensor],
117
+ sample_drop_ratio: float = 0.0,
118
+ pos=None,
119
+ ) -> Tensor:
120
+ # 1) extract subset using permutation
121
+ b, n, d = x.shape
122
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
123
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
124
+ x_subset = x[brange]
125
+
126
+ # 2) apply residual_func to get residual
127
+ if pos is not None:
128
+ # if necessary, apply rope to the subset
129
+ pos = pos[brange]
130
+ residual = residual_func(x_subset, pos=pos)
131
+ else:
132
+ residual = residual_func(x_subset)
133
+
134
+ x_flat = x.flatten(1)
135
+ residual = residual.flatten(1)
136
+
137
+ residual_scale_factor = b / sample_subset_size
138
+
139
+ # 3) add the residual
140
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
141
+ return x_plus_residual.view_as(x)
142
+
143
+
144
+ def get_branges_scales(x, sample_drop_ratio=0.0):
145
+ b, n, d = x.shape
146
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
147
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
148
+ residual_scale_factor = b / sample_subset_size
149
+ return brange, residual_scale_factor
150
+
151
+
152
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
153
+ if scaling_vector is None:
154
+ x_flat = x.flatten(1)
155
+ residual = residual.flatten(1)
156
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
157
+ else:
158
+ x_plus_residual = scaled_index_add(
159
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
160
+ )
161
+ return x_plus_residual
162
+
163
+
164
+ attn_bias_cache: Dict[Tuple, Any] = {}
165
+
166
+
167
+ def get_attn_bias_and_cat(x_list, branges=None):
168
+ """
169
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
170
+ """
171
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
172
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
173
+ if all_shapes not in attn_bias_cache.keys():
174
+ seqlens = []
175
+ for b, x in zip(batch_sizes, x_list):
176
+ for _ in range(b):
177
+ seqlens.append(x.shape[1])
178
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
179
+ attn_bias._batch_sizes = batch_sizes
180
+ attn_bias_cache[all_shapes] = attn_bias
181
+
182
+ if branges is not None:
183
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
184
+ else:
185
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
186
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
187
+
188
+ return attn_bias_cache[all_shapes], cat_tensors
189
+
190
+
191
+ def drop_add_residual_stochastic_depth_list(
192
+ x_list: List[Tensor],
193
+ residual_func: Callable[[Tensor, Any], Tensor],
194
+ sample_drop_ratio: float = 0.0,
195
+ scaling_vector=None,
196
+ ) -> Tensor:
197
+ # 1) generate random set of indices for dropping samples in the batch
198
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
199
+ branges = [s[0] for s in branges_scales]
200
+ residual_scale_factors = [s[1] for s in branges_scales]
201
+
202
+ # 2) get attention bias and index+concat the tensors
203
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
204
+
205
+ # 3) apply residual_func to get residual, and split the result
206
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
207
+
208
+ outputs = []
209
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
210
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
211
+ return outputs
212
+
213
+
214
+ class NestedTensorBlock(Block):
215
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
216
+ """
217
+ x_list contains a list of tensors to nest together and run
218
+ """
219
+ assert isinstance(self.attn, MemEffAttention)
220
+
221
+ if self.training and self.sample_drop_ratio > 0.0:
222
+
223
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
224
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
225
+
226
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
227
+ return self.mlp(self.norm2(x))
228
+
229
+ x_list = drop_add_residual_stochastic_depth_list(
230
+ x_list,
231
+ residual_func=attn_residual_func,
232
+ sample_drop_ratio=self.sample_drop_ratio,
233
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
234
+ )
235
+ x_list = drop_add_residual_stochastic_depth_list(
236
+ x_list,
237
+ residual_func=ffn_residual_func,
238
+ sample_drop_ratio=self.sample_drop_ratio,
239
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
240
+ )
241
+ return x_list
242
+ else:
243
+
244
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
245
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
246
+
247
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
248
+ return self.ls2(self.mlp(self.norm2(x)))
249
+
250
+ attn_bias, x = get_attn_bias_and_cat(x_list)
251
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
252
+ x = x + ffn_residual_func(x)
253
+ return attn_bias.split(x)
254
+
255
+ def forward(self, x_or_x_list):
256
+ if isinstance(x_or_x_list, Tensor):
257
+ return super().forward(x_or_x_list)
258
+ elif isinstance(x_or_x_list, list):
259
+ if not XFORMERS_AVAILABLE:
260
+ raise AssertionError("xFormers is required for using nested tensors")
261
+ return self.forward_nested(x_or_x_list)
262
+ else:
263
+ raise AssertionError
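
Block threads that cache through a pre-norm transformer block: with use_cache=True its forward returns both the updated tokens and this layer's new (k, v) pair, which the caller stores for the next streamed frame. A sketch with illustrative sizes:

import torch
from streamvggt.layers.block import Block

blk = Block(dim=64, num_heads=4).eval()
tokens = torch.randn(1, 10, 64)
with torch.no_grad():
    tokens, new_kv = blk(tokens, past_key_values=None, use_cache=True)
print(tokens.shape, new_kv[0].shape)   # torch.Size([1, 10, 64]) torch.Size([1, 4, 1, 10, 16])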
streamvggt/layers/drop_path.py ADDED
@@ -0,0 +1,24 @@
1
+ from torch import nn
2
+
3
+
4
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
5
+ if drop_prob == 0.0 or not training:
6
+ return x
7
+ keep_prob = 1 - drop_prob
8
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
9
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
10
+ if keep_prob > 0.0:
11
+ random_tensor.div_(keep_prob)
12
+ output = x * random_tensor
13
+ return output
14
+
15
+
16
+ class DropPath(nn.Module):
17
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
18
+
19
+ def __init__(self, drop_prob=None):
20
+ super(DropPath, self).__init__()
21
+ self.drop_prob = drop_prob
22
+
23
+ def forward(self, x):
24
+ return drop_path(x, self.drop_prob, self.training)
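
drop_path zeroes whole samples in the residual branch with probability drop_prob and rescales the survivors by 1/keep_prob, so the branch keeps its expected value during training. A tiny numerical check (sketch):

import torch
from streamvggt.layers.drop_path import drop_path

x = torch.ones(10000, 8)
y = drop_path(x, drop_prob=0.2, training=True)
print(y[y != 0].unique())    # tensor([1.2500]) -- surviving rows are scaled by 1 / 0.8
print(float(y.mean()))       # ~1.0: the expectation is preserved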
streamvggt/layers/layer_scale.py ADDED
@@ -0,0 +1,20 @@
1
+ from typing import Union
2
+
3
+ import torch
4
+ from torch import Tensor
5
+ from torch import nn
6
+
7
+
8
+ class LayerScale(nn.Module):
9
+ def __init__(
10
+ self,
11
+ dim: int,
12
+ init_values: Union[float, Tensor] = 1e-5,
13
+ inplace: bool = False,
14
+ ) -> None:
15
+ super().__init__()
16
+ self.inplace = inplace
17
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
18
+
19
+ def forward(self, x: Tensor) -> Tensor:
20
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
streamvggt/layers/mlp.py ADDED
@@ -0,0 +1,30 @@
1
+ from typing import Callable, Optional
2
+
3
+ from torch import Tensor, nn
4
+
5
+
6
+ class Mlp(nn.Module):
7
+ def __init__(
8
+ self,
9
+ in_features: int,
10
+ hidden_features: Optional[int] = None,
11
+ out_features: Optional[int] = None,
12
+ act_layer: Callable[..., nn.Module] = nn.GELU,
13
+ drop: float = 0.0,
14
+ bias: bool = True,
15
+ ) -> None:
16
+ super().__init__()
17
+ out_features = out_features or in_features
18
+ hidden_features = hidden_features or in_features
19
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
20
+ self.act = act_layer()
21
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
22
+ self.drop = nn.Dropout(drop)
23
+
24
+ def forward(self, x: Tensor) -> Tensor:
25
+ x = self.fc1(x)
26
+ x = self.act(x)
27
+ x = self.drop(x)
28
+ x = self.fc2(x)
29
+ x = self.drop(x)
30
+ return x
streamvggt/layers/patch_embed.py ADDED
@@ -0,0 +1,79 @@
1
+ from typing import Callable, Optional, Tuple, Union
2
+
3
+ from torch import Tensor
4
+ import torch.nn as nn
5
+
6
+
7
+ def make_2tuple(x):
8
+ if isinstance(x, tuple):
9
+ assert len(x) == 2
10
+ return x
11
+
12
+ assert isinstance(x, int)
13
+ return (x, x)
14
+
15
+
16
+ class PatchEmbed(nn.Module):
17
+ """
18
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
19
+
20
+ Args:
21
+ img_size: Image size.
22
+ patch_size: Patch token size.
23
+ in_chans: Number of input image channels.
24
+ embed_dim: Number of linear projection output channels.
25
+ norm_layer: Normalization layer.
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ img_size: Union[int, Tuple[int, int]] = 224,
31
+ patch_size: Union[int, Tuple[int, int]] = 16,
32
+ in_chans: int = 3,
33
+ embed_dim: int = 768,
34
+ norm_layer: Optional[Callable] = None,
35
+ flatten_embedding: bool = True,
36
+ ) -> None:
37
+ super().__init__()
38
+
39
+ image_HW = make_2tuple(img_size)
40
+ patch_HW = make_2tuple(patch_size)
41
+ patch_grid_size = (
42
+ image_HW[0] // patch_HW[0],
43
+ image_HW[1] // patch_HW[1],
44
+ )
45
+
46
+ self.img_size = image_HW
47
+ self.patch_size = patch_HW
48
+ self.patches_resolution = patch_grid_size
49
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
50
+
51
+ self.in_chans = in_chans
52
+ self.embed_dim = embed_dim
53
+
54
+ self.flatten_embedding = flatten_embedding
55
+
56
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
57
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
58
+
59
+ def forward(self, x: Tensor) -> Tensor:
60
+ _, _, H, W = x.shape
61
+ patch_H, patch_W = self.patch_size
62
+
63
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
64
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width {patch_W}"
65
+
66
+ x = self.proj(x) # B C H W
67
+ H, W = x.size(2), x.size(3)
68
+ x = x.flatten(2).transpose(1, 2) # B HW C
69
+ x = self.norm(x)
70
+ if not self.flatten_embedding:
71
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
72
+ return x
73
+
74
+ def flops(self) -> float:
75
+ Ho, Wo = self.patches_resolution
76
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
77
+ if self.norm is not None:
78
+ flops += Ho * Wo * self.embed_dim
79
+ return flops
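
PatchEmbed turns (B, 3, H, W) into (B, N, D) with N = (H / patch) * (W / patch); with the 518-pixel, 14-patch configuration used elsewhere in this commit that gives a 37 x 37 = 1369-token grid. Sketch:

import torch
from streamvggt.layers.patch_embed import PatchEmbed

pe = PatchEmbed(img_size=518, patch_size=14, in_chans=3, embed_dim=1024)
tokens = pe(torch.randn(1, 3, 518, 518))
print(pe.patches_resolution, tokens.shape)   # (37, 37) torch.Size([1, 1369, 1024])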
vggt/layers/__pycache__/rope.cpython-310.pyc → streamvggt/layers/rope.py RENAMED
Binary files a/vggt/layers/__pycache__/rope.cpython-310.pyc and b/streamvggt/layers/rope.py differ
 
streamvggt/layers/swiglu_ffn.py ADDED
@@ -0,0 +1,67 @@
1
+ import os
2
+ from typing import Callable, Optional
3
+ import warnings
4
+
5
+ from torch import Tensor, nn
6
+ import torch.nn.functional as F
7
+
8
+
9
+ class SwiGLUFFN(nn.Module):
10
+ def __init__(
11
+ self,
12
+ in_features: int,
13
+ hidden_features: Optional[int] = None,
14
+ out_features: Optional[int] = None,
15
+ act_layer: Callable[..., nn.Module] = None,
16
+ drop: float = 0.0,
17
+ bias: bool = True,
18
+ ) -> None:
19
+ super().__init__()
20
+ out_features = out_features or in_features
21
+ hidden_features = hidden_features or in_features
22
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
23
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
24
+
25
+ def forward(self, x: Tensor) -> Tensor:
26
+ x12 = self.w12(x)
27
+ x1, x2 = x12.chunk(2, dim=-1)
28
+ hidden = F.silu(x1) * x2
29
+ return self.w3(hidden)
30
+
31
+
32
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
33
+ # try:
34
+ # if XFORMERS_ENABLED:
35
+ # from xformers.ops import SwiGLU
36
+
37
+ # XFORMERS_AVAILABLE = True
38
+ # warnings.warn("xFormers is available (SwiGLU)")
39
+ # else:
40
+ # warnings.warn("xFormers is disabled (SwiGLU)")
41
+ # raise ImportError
42
+ # except ImportError:
43
+ SwiGLU = SwiGLUFFN
44
+ XFORMERS_AVAILABLE = False
45
+
46
+ # warnings.warn("xFormers is not available (SwiGLU)")
47
+
48
+
49
+ class SwiGLUFFNFused(SwiGLU):
50
+ def __init__(
51
+ self,
52
+ in_features: int,
53
+ hidden_features: Optional[int] = None,
54
+ out_features: Optional[int] = None,
55
+ act_layer: Callable[..., nn.Module] = None,
56
+ drop: float = 0.0,
57
+ bias: bool = True,
58
+ ) -> None:
59
+ out_features = out_features or in_features
60
+ hidden_features = hidden_features or in_features
61
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
62
+ super().__init__(
63
+ in_features=in_features,
64
+ hidden_features=hidden_features,
65
+ out_features=out_features,
66
+ bias=bias,
67
+ )
streamvggt/layers/vision_transformer.py ADDED
@@ -0,0 +1,398 @@
1
+ from functools import partial
2
+ import math
3
+ import logging
4
+ from typing import Sequence, Tuple, Union, Callable
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.utils.checkpoint import checkpoint
9
+ from torch.nn.init import trunc_normal_
10
+ from . import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
11
+
12
+ logger = logging.getLogger("dinov2")
13
+
14
+
15
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
16
+ if not depth_first and include_root:
17
+ fn(module=module, name=name)
18
+ for child_name, child_module in module.named_children():
19
+ child_name = ".".join((name, child_name)) if name else child_name
20
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
21
+ if depth_first and include_root:
22
+ fn(module=module, name=name)
23
+ return module
24
+
25
+
26
+ class BlockChunk(nn.ModuleList):
27
+ def forward(self, x):
28
+ for b in self:
29
+ x = b(x)
30
+ return x
31
+
32
+
33
+ class DinoVisionTransformer(nn.Module):
34
+ def __init__(
35
+ self,
36
+ img_size=224,
37
+ patch_size=16,
38
+ in_chans=3,
39
+ embed_dim=768,
40
+ depth=12,
41
+ num_heads=12,
42
+ mlp_ratio=4.0,
43
+ qkv_bias=True,
44
+ ffn_bias=True,
45
+ proj_bias=True,
46
+ drop_path_rate=0.0,
47
+ drop_path_uniform=False,
48
+ init_values=None, # for layerscale: None or 0 => no layerscale
49
+ embed_layer=PatchEmbed,
50
+ act_layer=nn.GELU,
51
+ block_fn=Block,
52
+ ffn_layer="mlp",
53
+ block_chunks=1,
54
+ num_register_tokens=0,
55
+ interpolate_antialias=False,
56
+ interpolate_offset=0.1,
57
+ qk_norm=False,
58
+ ):
59
+ """
60
+ Args:
61
+ img_size (int, tuple): input image size
62
+ patch_size (int, tuple): patch size
63
+ in_chans (int): number of input channels
64
+ embed_dim (int): embedding dimension
65
+ depth (int): depth of transformer
66
+ num_heads (int): number of attention heads
67
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
68
+ qkv_bias (bool): enable bias for qkv if True
69
+ proj_bias (bool): enable bias for proj in attn if True
70
+ ffn_bias (bool): enable bias for ffn if True
71
+ drop_path_rate (float): stochastic depth rate
72
+ drop_path_uniform (bool): apply uniform drop rate across blocks
73
+ weight_init (str): weight init scheme
74
+ init_values (float): layer-scale init values
75
+ embed_layer (nn.Module): patch embedding layer
76
+ act_layer (nn.Module): MLP activation layer
77
+ block_fn (nn.Module): transformer block class
78
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
79
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
80
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
81
+ interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
82
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
83
+ """
84
+ super().__init__()
85
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
86
+
87
+ # tricky but makes it work
88
+ self.use_checkpoint = False
89
+ #
90
+
91
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
92
+ self.num_tokens = 1
93
+ self.n_blocks = depth
94
+ self.num_heads = num_heads
95
+ self.patch_size = patch_size
96
+ self.num_register_tokens = num_register_tokens
97
+ self.interpolate_antialias = interpolate_antialias
98
+ self.interpolate_offset = interpolate_offset
99
+
100
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
101
+ num_patches = self.patch_embed.num_patches
102
+
103
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
104
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
105
+ assert num_register_tokens >= 0
106
+ self.register_tokens = (
107
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
108
+ )
109
+
110
+ if drop_path_uniform is True:
111
+ dpr = [drop_path_rate] * depth
112
+ else:
113
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
114
+
115
+ if ffn_layer == "mlp":
116
+ logger.info("using MLP layer as FFN")
117
+ ffn_layer = Mlp
118
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
119
+ logger.info("using SwiGLU layer as FFN")
120
+ ffn_layer = SwiGLUFFNFused
121
+ elif ffn_layer == "identity":
122
+ logger.info("using Identity layer as FFN")
123
+
124
+ def f(*args, **kwargs):
125
+ return nn.Identity()
126
+
127
+ ffn_layer = f
128
+ else:
129
+ raise NotImplementedError
130
+
131
+ blocks_list = [
132
+ block_fn(
133
+ dim=embed_dim,
134
+ num_heads=num_heads,
135
+ mlp_ratio=mlp_ratio,
136
+ qkv_bias=qkv_bias,
137
+ proj_bias=proj_bias,
138
+ ffn_bias=ffn_bias,
139
+ drop_path=dpr[i],
140
+ norm_layer=norm_layer,
141
+ act_layer=act_layer,
142
+ ffn_layer=ffn_layer,
143
+ init_values=init_values,
144
+ qk_norm=qk_norm,
145
+ )
146
+ for i in range(depth)
147
+ ]
148
+ if block_chunks > 0:
149
+ self.chunked_blocks = True
150
+ chunked_blocks = []
151
+ chunksize = depth // block_chunks
152
+ for i in range(0, depth, chunksize):
153
+ # this is to keep the block index consistent if we chunk the block list
154
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
155
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
156
+ else:
157
+ self.chunked_blocks = False
158
+ self.blocks = nn.ModuleList(blocks_list)
159
+
160
+ self.norm = norm_layer(embed_dim)
161
+ self.head = nn.Identity()
162
+
163
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
164
+
165
+ self.init_weights()
166
+
167
+ def init_weights(self):
168
+ trunc_normal_(self.pos_embed, std=0.02)
169
+ nn.init.normal_(self.cls_token, std=1e-6)
170
+ if self.register_tokens is not None:
171
+ nn.init.normal_(self.register_tokens, std=1e-6)
172
+ named_apply(init_weights_vit_timm, self)
173
+
174
+ def interpolate_pos_encoding(self, x, w, h):
175
+ previous_dtype = x.dtype
176
+ npatch = x.shape[1] - 1
177
+ N = self.pos_embed.shape[1] - 1
178
+ if npatch == N and w == h:
179
+ return self.pos_embed
180
+ pos_embed = self.pos_embed.float()
181
+ class_pos_embed = pos_embed[:, 0]
182
+ patch_pos_embed = pos_embed[:, 1:]
183
+ dim = x.shape[-1]
184
+ w0 = w // self.patch_size
185
+ h0 = h // self.patch_size
186
+ M = int(math.sqrt(N)) # Recover the number of patches in each dimension
187
+ assert N == M * M
188
+ kwargs = {}
189
+ if self.interpolate_offset:
190
+ # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
191
+ # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
192
+ sx = float(w0 + self.interpolate_offset) / M
193
+ sy = float(h0 + self.interpolate_offset) / M
194
+ kwargs["scale_factor"] = (sx, sy)
195
+ else:
196
+ # Simply specify an output size instead of a scale factor
197
+ kwargs["size"] = (w0, h0)
198
+ patch_pos_embed = nn.functional.interpolate(
199
+ patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
200
+ mode="bicubic",
201
+ antialias=self.interpolate_antialias,
202
+ **kwargs,
203
+ )
204
+ assert (w0, h0) == patch_pos_embed.shape[-2:]
205
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
206
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
207
+
208
+ def prepare_tokens_with_masks(self, x, masks=None):
209
+ B, nc, w, h = x.shape
210
+ x = self.patch_embed(x)
211
+ if masks is not None:
212
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
213
+
214
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
215
+ x = x + self.interpolate_pos_encoding(x, w, h)
216
+
217
+ if self.register_tokens is not None:
218
+ x = torch.cat(
219
+ (
220
+ x[:, :1],
221
+ self.register_tokens.expand(x.shape[0], -1, -1),
222
+ x[:, 1:],
223
+ ),
224
+ dim=1,
225
+ )
226
+
227
+ return x
228
+
229
+ def forward_features_list(self, x_list, masks_list):
230
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
231
+
232
+ for blk in self.blocks:
233
+ if self.use_checkpoint:
234
+ x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
235
+ else:
236
+ x = blk(x)
237
+
238
+ all_x = x
239
+ output = []
240
+ for x, masks in zip(all_x, masks_list):
241
+ x_norm = self.norm(x)
242
+ output.append(
243
+ {
244
+ "x_norm_clstoken": x_norm[:, 0],
245
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
246
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
247
+ "x_prenorm": x,
248
+ "masks": masks,
249
+ }
250
+ )
251
+ return output
252
+
253
+ def forward_features(self, x, masks=None):
254
+ if isinstance(x, list):
255
+ return self.forward_features_list(x, masks)
256
+
257
+ x = self.prepare_tokens_with_masks(x, masks)
258
+
259
+ for blk in self.blocks:
260
+ if self.use_checkpoint:
261
+ x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
262
+ else:
263
+ x = blk(x)
264
+
265
+ x_norm = self.norm(x)
266
+ return {
267
+ "x_norm_clstoken": x_norm[:, 0],
268
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
269
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
270
+ "x_prenorm": x,
271
+ "masks": masks,
272
+ }
273
+
274
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
275
+ x = self.prepare_tokens_with_masks(x)
276
+ # If n is an int, take the n last blocks. If it's a list, take them
277
+ output, total_block_len = [], len(self.blocks)
278
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
279
+ for i, blk in enumerate(self.blocks):
280
+ x = blk(x)
281
+ if i in blocks_to_take:
282
+ output.append(x)
283
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
284
+ return output
285
+
286
+ def _get_intermediate_layers_chunked(self, x, n=1):
287
+ x = self.prepare_tokens_with_masks(x)
288
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
289
+ # If n is an int, take the n last blocks. If it's a list, take them
290
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
291
+ for block_chunk in self.blocks:
292
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
293
+ x = blk(x)
294
+ if i in blocks_to_take:
295
+ output.append(x)
296
+ i += 1
297
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
298
+ return output
299
+
300
+ def get_intermediate_layers(
301
+ self,
302
+ x: torch.Tensor,
303
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
304
+ reshape: bool = False,
305
+ return_class_token: bool = False,
306
+ norm=True,
307
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
308
+ if self.chunked_blocks:
309
+ outputs = self._get_intermediate_layers_chunked(x, n)
310
+ else:
311
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
312
+ if norm:
313
+ outputs = [self.norm(out) for out in outputs]
314
+ class_tokens = [out[:, 0] for out in outputs]
315
+ outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
316
+ if reshape:
317
+ B, _, w, h = x.shape
318
+ outputs = [
319
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
320
+ for out in outputs
321
+ ]
322
+ if return_class_token:
323
+ return tuple(zip(outputs, class_tokens))
324
+ return tuple(outputs)
325
+
326
+ def forward(self, *args, is_training=True, **kwargs):
327
+ ret = self.forward_features(*args, **kwargs)
328
+ if is_training:
329
+ return ret
330
+ else:
331
+ return self.head(ret["x_norm_clstoken"])
332
+
333
+
334
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
335
+ """ViT weight initialization, original timm impl (for reproducibility)"""
336
+ if isinstance(module, nn.Linear):
337
+ trunc_normal_(module.weight, std=0.02)
338
+ if module.bias is not None:
339
+ nn.init.zeros_(module.bias)
340
+
341
+
342
+ def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
343
+ model = DinoVisionTransformer(
344
+ patch_size=patch_size,
345
+ embed_dim=384,
346
+ depth=12,
347
+ num_heads=6,
348
+ mlp_ratio=4,
349
+ block_fn=partial(Block, attn_class=MemEffAttention),
350
+ num_register_tokens=num_register_tokens,
351
+ **kwargs,
352
+ )
353
+ return model
354
+
355
+
356
+ def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
357
+ model = DinoVisionTransformer(
358
+ patch_size=patch_size,
359
+ embed_dim=768,
360
+ depth=12,
361
+ num_heads=12,
362
+ mlp_ratio=4,
363
+ block_fn=partial(Block, attn_class=MemEffAttention),
364
+ num_register_tokens=num_register_tokens,
365
+ **kwargs,
366
+ )
367
+ return model
368
+
369
+
370
+ def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
371
+ model = DinoVisionTransformer(
372
+ patch_size=patch_size,
373
+ embed_dim=1024,
374
+ depth=24,
375
+ num_heads=16,
376
+ mlp_ratio=4,
377
+ block_fn=partial(Block, attn_class=MemEffAttention),
378
+ num_register_tokens=num_register_tokens,
379
+ **kwargs,
380
+ )
381
+ return model
382
+
383
+
384
+ def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
385
+ """
386
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
387
+ """
388
+ model = DinoVisionTransformer(
389
+ patch_size=patch_size,
390
+ embed_dim=1536,
391
+ depth=40,
392
+ num_heads=24,
393
+ mlp_ratio=4,
394
+ block_fn=partial(Block, attn_class=MemEffAttention),
395
+ num_register_tokens=num_register_tokens,
396
+ **kwargs,
397
+ )
398
+ return model
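
get_intermediate_layers returns the (optionally normalized) patch tokens of the last n blocks with the class and register tokens stripped, and can reshape them into feature maps; this is the DINOv2-style interface the Aggregator below builds on. A sketch with the small variant and illustrative settings (not necessarily the configuration of the published checkpoint):

import torch
from streamvggt.layers.vision_transformer import vit_small

model = vit_small(patch_size=14, img_size=518, num_register_tokens=4, block_chunks=0).eval()
with torch.no_grad():
    feats = model.get_intermediate_layers(torch.randn(1, 3, 518, 518), n=2, reshape=True)
print([f.shape for f in feats])   # two maps of shape torch.Size([1, 384, 37, 37])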
streamvggt/models/aggregator.py ADDED
@@ -0,0 +1,394 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from typing import Optional, Tuple, Union, List, Dict, Any
12
+
13
+ from streamvggt.layers import PatchEmbed
14
+ from streamvggt.layers.block import Block
15
+ from streamvggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
16
+ from streamvggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ _RESNET_MEAN = [0.485, 0.456, 0.406]
21
+ _RESNET_STD = [0.229, 0.224, 0.225]
22
+
23
+
24
+ class Aggregator(nn.Module):
25
+ """
26
+ The Aggregator applies alternating-attention over input frames,
27
+ as described in VGGT: Visual Geometry Grounded Transformer.
28
+
29
+
30
+ Args:
31
+ img_size (int): Image size in pixels.
32
+ patch_size (int): Size of each patch for PatchEmbed.
33
+ embed_dim (int): Dimension of the token embeddings.
34
+ depth (int): Number of blocks.
35
+ num_heads (int): Number of attention heads.
36
+ mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
37
+ num_register_tokens (int): Number of register tokens.
38
+ block_fn (nn.Module): The block type used for attention (Block by default).
39
+ qkv_bias (bool): Whether to include bias in QKV projections.
40
+ proj_bias (bool): Whether to include bias in the output projection.
41
+ ffn_bias (bool): Whether to include bias in MLP layers.
42
+ patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg".
43
+ aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"].
44
+ aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1.
45
+ qk_norm (bool): Whether to apply QK normalization.
46
+ rope_freq (int): Base frequency for rotary embedding. -1 to disable.
47
+ init_values (float): Init scale for layer scale.
48
+ """
49
+
50
+ def __init__(
51
+ self,
52
+ img_size=518,
53
+ patch_size=14,
54
+ embed_dim=1024,
55
+ depth=24,
56
+ num_heads=16,
57
+ mlp_ratio=4.0,
58
+ num_register_tokens=4,
59
+ block_fn=Block,
60
+ qkv_bias=True,
61
+ proj_bias=True,
62
+ ffn_bias=True,
63
+ patch_embed="dinov2_vitl14_reg",
64
+ aa_order=["frame", "global"],
65
+ aa_block_size=1,
66
+ qk_norm=True,
67
+ rope_freq=100,
68
+ init_values=0.01,
69
+ ):
70
+ super().__init__()
71
+
72
+ self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim)
73
+
74
+ # Initialize rotary position embedding if frequency > 0
75
+ self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
76
+ self.position_getter = PositionGetter() if self.rope is not None else None
77
+
78
+ self.frame_blocks = nn.ModuleList(
79
+ [
80
+ block_fn(
81
+ dim=embed_dim,
82
+ num_heads=num_heads,
83
+ mlp_ratio=mlp_ratio,
84
+ qkv_bias=qkv_bias,
85
+ proj_bias=proj_bias,
86
+ ffn_bias=ffn_bias,
87
+ init_values=init_values,
88
+ qk_norm=qk_norm,
89
+ rope=self.rope,
90
+ )
91
+ for _ in range(depth)
92
+ ]
93
+ )
94
+
95
+ self.global_blocks = nn.ModuleList(
96
+ [
97
+ block_fn(
98
+ dim=embed_dim,
99
+ num_heads=num_heads,
100
+ mlp_ratio=mlp_ratio,
101
+ qkv_bias=qkv_bias,
102
+ proj_bias=proj_bias,
103
+ ffn_bias=ffn_bias,
104
+ init_values=init_values,
105
+ qk_norm=qk_norm,
106
+ rope=self.rope,
107
+ )
108
+ for _ in range(depth)
109
+ ]
110
+ )
111
+
112
+ self.depth = depth
113
+ self.aa_order = aa_order
114
+ self.patch_size = patch_size
115
+ self.aa_block_size = aa_block_size
116
+
117
+ # Validate that depth is divisible by aa_block_size
118
+ if self.depth % self.aa_block_size != 0:
119
+ raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")
120
+
121
+ self.aa_block_num = self.depth // self.aa_block_size
122
+
123
+ # Note: We have two camera tokens, one for the first frame and one for the rest
124
+ # The same applies for register tokens
125
+ self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
126
+ self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim))
127
+
128
+ # The patch tokens start after the camera and register tokens
129
+ self.patch_start_idx = 1 + num_register_tokens
130
+
131
+ # Initialize parameters with small values
132
+ nn.init.normal_(self.camera_token, std=1e-6)
133
+ nn.init.normal_(self.register_token, std=1e-6)
134
+
135
+ # Register normalization constants as buffers
136
+ for name, value in (
137
+ ("_resnet_mean", _RESNET_MEAN),
138
+ ("_resnet_std", _RESNET_STD),
139
+ ):
140
+ self.register_buffer(
141
+ name,
142
+ torch.FloatTensor(value).reshape(1, 1, 3, 1, 1),
143
+ persistent=False,
144
+ )
145
+
146
+
147
+ def __build_patch_embed__(
148
+ self,
149
+ patch_embed,
150
+ img_size,
151
+ patch_size,
152
+ num_register_tokens,
153
+ interpolate_antialias=True,
154
+ interpolate_offset=0.0,
155
+ block_chunks=0,
156
+ init_values=1.0,
157
+ embed_dim=1024,
158
+ ):
159
+ """
160
+ Build the patch embed layer. If 'conv', we use a
161
+ simple PatchEmbed conv layer. Otherwise, we use a vision transformer.
162
+ """
163
+
164
+ if "conv" in patch_embed:
165
+ self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim)
166
+ else:
167
+ vit_models = {
168
+ "dinov2_vitl14_reg": vit_large,
169
+ "dinov2_vitb14_reg": vit_base,
170
+ "dinov2_vits14_reg": vit_small,
171
+ "dinov2_vitg2_reg": vit_giant2,
172
+ }
173
+
174
+ self.patch_embed = vit_models[patch_embed](
175
+ img_size=img_size,
176
+ patch_size=patch_size,
177
+ num_register_tokens=num_register_tokens,
178
+ interpolate_antialias=interpolate_antialias,
179
+ interpolate_offset=interpolate_offset,
180
+ block_chunks=block_chunks,
181
+ init_values=init_values,
182
+ )
183
+
184
+ # Disable gradient updates for mask token
185
+ if hasattr(self.patch_embed, "mask_token"):
186
+ self.patch_embed.mask_token.requires_grad_(False)
187
+
188
+ def forward(
189
+ self,
190
+ images: torch.Tensor,
191
+ past_key_values=None,
192
+ use_cache=False,
193
+ past_frame_idx=0
194
+ ) -> Tuple[List[torch.Tensor], int]:
195
+ """
196
+ Args:
197
+ images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
198
+ B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
199
+
200
+ Returns:
201
+ (list[torch.Tensor], int):
202
+ The list of outputs from the attention blocks,
203
+ and the patch_start_idx indicating where patch tokens begin.
204
+ """
205
+ B, S, C_in, H, W = images.shape
206
+
207
+ if use_cache and past_key_values[0] is not None:
208
+ _, _, S_true, _, _ = past_key_values[0][0].shape
209
+ S_true += 1
210
+ else:
211
+ S_true = S
212
+
213
+ if use_cache and S > 1:
214
+ print(f"Use KV cache expects S=1, got S={S}")
215
+
216
+ if C_in != 3:
217
+ raise ValueError(f"Expected 3 input channels, got {C_in}")
218
+
219
+ # Normalize images and reshape for patch embed
220
+ images = (images - self._resnet_mean.to(images.device)) / self._resnet_std.to(images.device)
221
+
222
+ # Reshape to [B*S, C, H, W] for patch embedding
223
+ images = images.reshape(B * S, C_in, H, W)
224
+ patch_tokens = self.patch_embed(images)
225
+
226
+ if isinstance(patch_tokens, dict):
227
+ patch_tokens = patch_tokens["x_norm_patchtokens"]
228
+
229
+ _, P, C = patch_tokens.shape
230
+
231
+ if use_cache:
232
+ camera_token_full = slice_expand_and_flatten(self.camera_token, B, S_true)
233
+ camera_token = camera_token_full[-1:, :, :]
234
+
235
+ register_token_full = slice_expand_and_flatten(self.register_token, B, S_true)
236
+ register_token = register_token_full[-1:, :, :]
237
+ else:
238
+ camera_token = slice_expand_and_flatten(self.camera_token, B, S)
239
+ register_token = slice_expand_and_flatten(self.register_token, B, S)
240
+ # Concatenate special tokens with patch tokens
241
+ tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1)
242
+
243
+ pos = None
244
+ if self.rope is not None:
245
+ pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device)
246
+
247
+ if self.patch_start_idx > 0:
248
+ # do not use position embedding for special tokens (camera and register tokens)
249
+ # so set pos to 0 for the special tokens
250
+ pos = pos + 1
251
+ pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype)
252
+ pos = torch.cat([pos_special, pos], dim=1)
253
+
254
+ # update P because we added special tokens
255
+ _, P, C = tokens.shape
256
+
257
+ frame_idx = 0
258
+ global_idx = 0
259
+ output_list = []
260
+
261
+ for _ in range(self.aa_block_num):
262
+ for attn_type in self.aa_order:
263
+ if attn_type == "frame":
264
+ tokens, frame_idx, frame_intermediates = self._process_frame_attention(
265
+ tokens, B, S, P, C, frame_idx, pos=pos
266
+ )
267
+ elif attn_type == "global":
268
+ if use_cache:
269
+ if past_key_values[global_idx] is not None:
270
+ k, v = past_key_values[global_idx]
271
+ tokens, global_idx, global_intermediates, new_kv = self._process_global_attention(
272
+ tokens, B, S, P, C, global_idx, pos=pos,
273
+ past_key_values_block=past_key_values[global_idx] if past_key_values[global_idx] is not None else None,
274
+ use_cache=True,
275
+ past_frame_idx=past_frame_idx
276
+ )
277
+ past_key_values[global_idx - 1] = new_kv
278
+ else:
279
+ tokens, global_idx, global_intermediates = self._process_global_attention(
280
+ tokens, B, S, P, C, global_idx, pos=pos
281
+ )
282
+ else:
283
+ raise ValueError(f"Unknown attention type: {attn_type}")
284
+ for i in range(len(frame_intermediates)):
285
+ # concat frame and global intermediates, [B x S x P x 2C]
286
+ concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
287
+ output_list.append(concat_inter)
288
+
289
+ del concat_inter
290
+ del frame_intermediates
291
+ del global_intermediates
292
+ if use_cache:
293
+ return output_list, self.patch_start_idx, past_key_values
294
+ return output_list, self.patch_start_idx
295
+
296
+ def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
297
+ """
298
+ Process frame attention blocks. We keep tokens in shape (B*S, P, C).
299
+ """
300
+ # If needed, reshape tokens or positions:
301
+ if tokens.shape != (B * S, P, C):
302
+ tokens = tokens.reshape(B, S, P, C).reshape(B * S, P, C)
303
+
304
+ if pos is not None and pos.shape != (B * S, P, 2):
305
+ pos = pos.reshape(B, S, P, 2).reshape(B * S, P, 2)
306
+
307
+ intermediates = []
308
+
309
+ # by default, self.aa_block_size=1, which processes one block at a time
310
+ for _ in range(self.aa_block_size):
311
+ tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
312
+ frame_idx += 1
313
+ intermediates.append(tokens.reshape(B, S, P, C))
314
+
315
+ return tokens, frame_idx, intermediates
316
+
317
+ def _process_global_attention(
318
+ self,
319
+ tokens,
320
+ B,
321
+ S,
322
+ P,
323
+ C,
324
+ global_idx,
325
+ pos=None,
326
+ past_key_values_block=None,
327
+ use_cache=False,
328
+ past_frame_idx=0
329
+ ) -> Union[Tuple[torch.Tensor, int, List[torch.Tensor]], Tuple[torch.Tensor, int, List[torch.Tensor], List]]:
330
+ """
331
+ Process global attention blocks. We keep tokens in shape (B, S*P, C).
332
+ """
333
+
334
+ if tokens.shape != (B, S * P, C):
335
+ tokens = tokens.reshape(B, S, P, C).reshape(B, S * P, C)
336
+
337
+ if pos is not None and pos.shape != (B, S * P, 2):
338
+ pos = pos.reshape(B, S, P, 2).reshape(B, S * P, 2)
339
+
340
+ intermediates = []
341
+
342
+ for _ in range(self.aa_block_size):
343
+ if not use_cache:
344
+ L = S * P
345
+ frame_ids = torch.arange(L, device=tokens.device) // P # [0,0,...,1,1,...,S-1]
346
+ future_frame = frame_ids.unsqueeze(1) < frame_ids.unsqueeze(0)
347
+ attn_mask = future_frame.to(tokens.dtype) * torch.finfo(tokens.dtype).min
348
+ else:
349
+ attn_mask = None
350
+
351
+ if use_cache:
352
+ tokens, block_kv = self.global_blocks[global_idx](
353
+ tokens,
354
+ pos=pos,
355
+ attn_mask=attn_mask,
356
+ past_key_values=past_key_values_block,
357
+ use_cache=True
358
+ )
359
+ else:
360
+ tokens = self.global_blocks[global_idx](tokens, pos=pos, attn_mask=attn_mask)
361
+ global_idx += 1
362
+ intermediates.append(tokens.reshape(B, S, P, C))
363
+
364
+ # if self.use_causal_global:
365
+ # del attn_mask
366
+ if use_cache:
367
+ return tokens, global_idx, intermediates, block_kv
368
+ return tokens, global_idx, intermediates
369
+
370
+
371
+ def slice_expand_and_flatten(token_tensor, B, S):
372
+ """
373
+ Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing:
374
+ 1) Uses the first position (index=0) for the first frame only
375
+ 2) Uses the second position (index=1) for all remaining frames (S-1 frames)
376
+ 3) Expands both to match batch size B
377
+ 4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token
378
+ followed by (S-1) second-position tokens
379
+ 5) Flattens to (B*S, X, C) for processing
380
+
381
+ Returns:
382
+ torch.Tensor: Processed tokens with shape (B*S, X, C)
383
+ """
384
+
385
+ # Slice out the "query" tokens => shape (1, 1, ...)
386
+ query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:])
387
+ # Slice out the "other" tokens => shape (1, S-1, ...)
388
+ others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:])
389
+ # Concatenate => shape (B, S, ...)
390
+ combined = torch.cat([query, others], dim=1)
391
+
392
+ # Finally flatten => shape (B*S, ...)
393
+ combined = combined.reshape(B * S, *combined.shape[2:])
394
+ return combined
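
The frame-causal mask built inside _process_global_attention is easiest to see on a toy example. The following standalone sketch is not part of the repository; S=3 frames and P=2 patch tokens per frame are made-up values chosen only to keep the printout small.

import torch

S, P = 3, 2
L = S * P
frame_ids = torch.arange(L) // P                       # tensor([0, 0, 1, 1, 2, 2])
future_frame = frame_ids.unsqueeze(1) < frame_ids.unsqueeze(0)
attn_mask = future_frame.float() * torch.finfo(torch.float32).min
# Row i receives a large negative bias at every column j whose frame index is
# strictly greater than that of token i, so a token can attend to its own frame
# and to earlier frames but never to later ones.
print(future_frame.int())

This mirrors the cached streaming branch above, where attn_mask is None simply because later frames have not arrived yet.
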
streamvggt/models/streamvggt.py ADDED
@@ -0,0 +1,172 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from huggingface_hub import PyTorchModelHubMixin # used for model hub
4
+
5
+ from streamvggt.models.aggregator import Aggregator
6
+ from streamvggt.heads.camera_head import CameraHead
7
+ from streamvggt.heads.dpt_head import DPTHead
8
+ from streamvggt.heads.track_head import TrackHead
9
+ from transformers.file_utils import ModelOutput
10
+ from typing import Optional, Tuple, List, Any
11
+ from dataclasses import dataclass
12
+
13
+ @dataclass
14
+ class StreamVGGTOutput(ModelOutput):
15
+ ress: Optional[List[dict]] = None
16
+ views: Optional[torch.Tensor] = None
17
+
18
+ class StreamVGGT(nn.Module, PyTorchModelHubMixin):
19
+ def __init__(self, img_size=518, patch_size=14, embed_dim=1024):
20
+ super().__init__()
21
+
22
+ self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
23
+ self.camera_head = CameraHead(dim_in=2 * embed_dim)
24
+ self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1")
25
+ self.depth_head = DPTHead(dim_in=2 * embed_dim, output_dim=2, activation="exp", conf_activation="expp1")
26
+ self.track_head = TrackHead(dim_in=2 * embed_dim, patch_size=patch_size)
27
+
28
+
29
+
30
+ def forward(
31
+ self,
32
+ views,
33
+ query_points: torch.Tensor = None,
34
+ history_info: Optional[dict] = None,
35
+ past_key_values=None,
36
+ use_cache=False,
37
+ past_frame_idx=0
38
+ ):
39
+ images = torch.stack(
40
+ [view["img"] for view in views], dim=0
41
+ ).permute(1, 0, 2, 3, 4) # B S C H W
42
+
43
+ # If the batch dimension is missing, add it
44
+ if len(images.shape) == 4:
45
+ images = images.unsqueeze(0)
46
+ if query_points is not None and len(query_points.shape) == 2:
47
+ query_points = query_points.unsqueeze(0)
48
+
49
+ if history_info is None:
50
+ history_info = {"token": None}
51
+
52
+ aggregated_tokens_list, patch_start_idx = self.aggregator(images)
53
+ predictions = {}
54
+
55
+ with torch.cuda.amp.autocast(enabled=False):
56
+ if self.camera_head is not None:
57
+ pose_enc_list = self.camera_head(aggregated_tokens_list)
58
+ predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
59
+
60
+ if self.depth_head is not None:
61
+ depth, depth_conf = self.depth_head(
62
+ aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
63
+ )
64
+ predictions["depth"] = depth
65
+ predictions["depth_conf"] = depth_conf
66
+
67
+ if self.point_head is not None:
68
+ pts3d, pts3d_conf = self.point_head(
69
+ aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
70
+ )
71
+ predictions["world_points"] = pts3d
72
+ predictions["world_points_conf"] = pts3d_conf
73
+
74
+ if self.track_head is not None and query_points is not None:
75
+ track_list, vis, conf = self.track_head(
76
+ aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx, query_points=query_points
77
+ )
78
+ predictions["track"] = track_list[-1] # track of the last iteration
79
+ predictions["vis"] = vis
80
+ predictions["conf"] = conf
81
+ predictions["images"] = images
82
+
83
+ B, S = images.shape[:2]
84
+ ress = []
85
+ for s in range(S):
86
+ res = {
87
+ 'pts3d_in_other_view': predictions['world_points'][:, s], # [B, H, W, 3]
88
+ 'conf': predictions['world_points_conf'][:, s], # [B, H, W]
89
+
90
+ 'depth': predictions['depth'][:, s], # [B, H, W, 1]
91
+ 'depth_conf': predictions['depth_conf'][:, s], # [B, H, W]
92
+ 'camera_pose': predictions['pose_enc'][:, s, :], # [B, 9]
93
+
94
+ **({'valid_mask': views[s]["valid_mask"]}
95
+ if 'valid_mask' in views[s] else {}), # [B, H, W]
96
+
97
+ **({'track': predictions['track'][:, s], # [B, N, 2]
98
+ 'vis': predictions['vis'][:, s], # [B, N]
99
+ 'track_conf': predictions['conf'][:, s]}
100
+ if 'track' in predictions else {})
101
+ }
102
+ ress.append(res)
103
+ return StreamVGGTOutput(ress=ress, views=views) # [S] [B, C, H, W]
104
+
105
+ def inference(self, frames, query_points: torch.Tensor = None, past_key_values=None):
106
+ past_key_values = [None] * self.aggregator.depth
107
+ past_key_values_camera = [None] * self.camera_head.trunk_depth
108
+
109
+ all_ress = []
110
+ processed_frames = []
111
+
112
+ for i, frame in enumerate(frames):
113
+ images = frame["img"].unsqueeze(0)
114
+ aggregator_output = self.aggregator(
115
+ images,
116
+ past_key_values=past_key_values,
117
+ use_cache=True,
118
+ past_frame_idx=i
119
+ )
120
+
121
+ if isinstance(aggregator_output, tuple) and len(aggregator_output) == 3:
122
+ aggregated_tokens, patch_start_idx, past_key_values = aggregator_output
123
+ else:
124
+ aggregated_tokens, patch_start_idx = aggregator_output
125
+
126
+ with torch.cuda.amp.autocast(enabled=False):
127
+ if self.camera_head is not None:
128
+ pose_enc, past_key_values_camera = self.camera_head(aggregated_tokens, past_key_values_camera=past_key_values_camera, use_cache=True)
129
+ pose_enc = pose_enc[-1]
130
+ camera_pose = pose_enc[:, 0, :]
131
+
132
+ if self.depth_head is not None:
133
+ depth, depth_conf = self.depth_head(
134
+ aggregated_tokens, images=images, patch_start_idx=patch_start_idx
135
+ )
136
+ depth = depth[:, 0]
137
+ depth_conf = depth_conf[:, 0]
138
+
139
+ if self.point_head is not None:
140
+ pts3d, pts3d_conf = self.point_head(
141
+ aggregated_tokens, images=images, patch_start_idx=patch_start_idx
142
+ )
143
+ pts3d = pts3d[:, 0]
144
+ pts3d_conf = pts3d_conf[:, 0]
145
+
146
+ if self.track_head is not None and query_points is not None:
147
+ track_list, vis, conf = self.track_head(
148
+ aggregated_tokens, images=images, patch_start_idx=patch_start_idx, query_points=query_points
149
+ )
150
+ track = track_list[-1][:, 0]
151
+ query_points = track
152
+ vis = vis[:, 0]
153
+ track_conf = conf[:, 0]
154
+
155
+ all_ress.append({
156
+ 'pts3d_in_other_view': pts3d,
157
+ 'conf': pts3d_conf,
158
+ 'depth': depth,
159
+ 'depth_conf': depth_conf,
160
+ 'camera_pose': camera_pose,
161
+ **({'valid_mask': frame["valid_mask"]}
162
+ if 'valid_mask' in frame else {}),
163
+
164
+ **({'track': track,
165
+ 'vis': vis,
166
+ 'track_conf': track_conf}
167
+ if query_points is not None else {})
168
+ })
169
+ processed_frames.append(frame)
170
+
171
+ output = StreamVGGTOutput(ress=all_ress, views=processed_frames)
172
+ return output
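
The streaming entry point above consumes one frame dict at a time and threads key/value caches through the aggregator and camera head. A hedged usage sketch follows; the random tensors stand in for real preprocessed frames, and 518 matches the default img_size of this class.

import torch
from streamvggt.models.streamvggt import StreamVGGT

model = StreamVGGT().eval()                                      # randomly initialised weights
frames = [{"img": torch.rand(3, 518, 518)} for _ in range(4)]    # four dummy frames

with torch.no_grad():
    output = model.inference(frames)

print(len(output.ress))        # 4: one result dict per streamed frame
print(sorted(output.ress[0]))  # camera_pose, conf, depth, depth_conf, pts3d_in_other_view

In practice the frames would come from load_and_preprocess_images in streamvggt/utils/load_fn.py, and pretrained weights could be loaded through the from_pretrained interface provided by PyTorchModelHubMixin.
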
streamvggt/utils/geometry.py ADDED
@@ -0,0 +1,166 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ import torch
9
+ import numpy as np
10
+
11
+
12
+ def unproject_depth_map_to_point_map(
13
+ depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray
14
+ ) -> np.ndarray:
15
+ """
16
+ Unproject a batch of depth maps to 3D world coordinates.
17
+
18
+ Args:
19
+ depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W)
20
+ extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4)
21
+ intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3)
22
+
23
+ Returns:
24
+ np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3)
25
+ """
26
+ if isinstance(depth_map, torch.Tensor):
27
+ depth_map = depth_map.cpu().numpy()
28
+ if isinstance(extrinsics_cam, torch.Tensor):
29
+ extrinsics_cam = extrinsics_cam.cpu().numpy()
30
+ if isinstance(intrinsics_cam, torch.Tensor):
31
+ intrinsics_cam = intrinsics_cam.cpu().numpy()
32
+
33
+ world_points_list = []
34
+ for frame_idx in range(depth_map.shape[0]):
35
+ cur_world_points, _, _ = depth_to_world_coords_points(
36
+ depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx]
37
+ )
38
+ world_points_list.append(cur_world_points)
39
+ world_points_array = np.stack(world_points_list, axis=0)
40
+
41
+ return world_points_array
42
+
43
+
44
+ def depth_to_world_coords_points(
45
+ depth_map: np.ndarray,
46
+ extrinsic: np.ndarray,
47
+ intrinsic: np.ndarray,
48
+ eps=1e-8,
49
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
50
+ """
51
+ Convert a depth map to world coordinates.
52
+
53
+ Args:
54
+ depth_map (np.ndarray): Depth map of shape (H, W).
55
+ intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
56
+ extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world.
57
+
58
+ Returns:
59
+ tuple[np.ndarray, np.ndarray, np.ndarray]: World coordinates (H, W, 3), camera coordinates (H, W, 3), and valid depth mask (H, W).
60
+ """
61
+ if depth_map is None:
62
+ return None, None, None
63
+
64
+ # Valid depth mask
65
+ point_mask = depth_map > eps
66
+
67
+ # Convert depth map to camera coordinates
68
+ cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic)
69
+
70
+ # Multiply with the inverse of extrinsic matrix to transform to world coordinates
71
+ # extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4))
72
+ cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0]
73
+
74
+ R_cam_to_world = cam_to_world_extrinsic[:3, :3]
75
+ t_cam_to_world = cam_to_world_extrinsic[:3, 3]
76
+
77
+ # Apply the rotation and translation to the camera coordinates
78
+ world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world # HxWx3, 3x3 -> HxWx3
79
+ # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world
80
+
81
+ return world_coords_points, cam_coords_points, point_mask
82
+
83
+
84
+ def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
85
+ """
86
+ Convert a depth map to camera coordinates.
87
+
88
+ Args:
89
+ depth_map (np.ndarray): Depth map of shape (H, W).
90
+ intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
91
+
92
+ Returns:
93
+ np.ndarray: Camera coordinates of shape (H, W, 3)
94
+ """
95
+ H, W = depth_map.shape
96
+ assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3"
97
+ assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew"
98
+
99
+ # Intrinsic parameters
100
+ fu, fv = intrinsic[0, 0], intrinsic[1, 1]
101
+ cu, cv = intrinsic[0, 2], intrinsic[1, 2]
102
+
103
+ # Generate grid of pixel coordinates
104
+ u, v = np.meshgrid(np.arange(W), np.arange(H))
105
+
106
+ # Unproject to camera coordinates
107
+ x_cam = (u - cu) * depth_map / fu
108
+ y_cam = (v - cv) * depth_map / fv
109
+ z_cam = depth_map
110
+
111
+ # Stack to form camera coordinates
112
+ cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
113
+
114
+ return cam_coords
115
+
116
+
117
+ def closed_form_inverse_se3(se3, R=None, T=None):
118
+ """
119
+ Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.
120
+
121
+ If `R` and `T` are provided, they must correspond to the rotation and translation
122
+ components of `se3`. Otherwise, they will be extracted from `se3`.
123
+
124
+ Args:
125
+ se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
126
+ R (optional): Nx3x3 array or tensor of rotation matrices.
127
+ T (optional): Nx3x1 array or tensor of translation vectors.
128
+
129
+ Returns:
130
+ Inverted SE3 matrices with the same type and device as `se3`.
131
+
132
+ Shapes:
133
+ se3: (N, 4, 4)
134
+ R: (N, 3, 3)
135
+ T: (N, 3, 1)
136
+ """
137
+ # Check if se3 is a numpy array or a torch tensor
138
+ is_numpy = isinstance(se3, np.ndarray)
139
+
140
+ # Validate shapes
141
+ if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
142
+ raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.")
143
+
144
+ # Extract R and T if not provided
145
+ if R is None:
146
+ R = se3[:, :3, :3] # (N,3,3)
147
+ if T is None:
148
+ T = se3[:, :3, 3:] # (N,3,1)
149
+
150
+ # Transpose R
151
+ if is_numpy:
152
+ # Compute the transpose of the rotation for NumPy
153
+ R_transposed = np.transpose(R, (0, 2, 1))
154
+ # -R^T t for NumPy
155
+ top_right = -np.matmul(R_transposed, T)
156
+ inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
157
+ else:
158
+ R_transposed = R.transpose(1, 2) # (N,3,3)
159
+ top_right = -torch.bmm(R_transposed, T) # (N,3,1)
160
+ inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
161
+ inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)
162
+
163
+ inverted_matrix[:, :3, :3] = R_transposed
164
+ inverted_matrix[:, :3, 3:] = top_right
165
+
166
+ return inverted_matrix
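
A quick numerical sanity check of the unprojection helpers above, using made-up values: with identity rotation, zero translation, and a constant depth of 1, the z component of the recovered world points should be 1 everywhere.

import numpy as np
from streamvggt.utils.geometry import unproject_depth_map_to_point_map

depth = np.ones((1, 4, 4), dtype=np.float32)                               # (S=1, H=4, W=4)
extrinsics = np.concatenate([np.eye(3), np.zeros((3, 1))], axis=1)[None]   # (1, 3, 4), camera == world
intrinsics = np.array([[[2.0, 0.0, 2.0],
                        [0.0, 2.0, 2.0],
                        [0.0, 0.0, 1.0]]])                                 # (1, 3, 3), fu=fv=2, cu=cv=2
world = unproject_depth_map_to_point_map(depth, extrinsics, intrinsics)
print(world.shape)         # (1, 4, 4, 3)
print(world[0, ..., 2])    # all ones: the depth value becomes the z coordinate
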
streamvggt/utils/load_fn.py ADDED
@@ -0,0 +1,146 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from PIL import Image
9
+ from torchvision import transforms as TF
10
+
11
+
12
+ def load_and_preprocess_images(image_path_list, mode="crop"):
13
+ """
14
+ A quick start function to load and preprocess images for model input.
15
+ This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes.
16
+
17
+ Args:
18
+ image_path_list (list): List of paths to image files
19
+ mode (str, optional): Preprocessing mode, either "crop" or "pad".
20
+ - "crop" (default): Sets width to 518px and center crops height if needed.
21
+ - "pad": Preserves all pixels by making the largest dimension 518px
22
+ and padding the smaller dimension to reach a square shape.
23
+
24
+ Returns:
25
+ torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W)
26
+
27
+ Raises:
28
+ ValueError: If the input list is empty or if mode is invalid
29
+
30
+ Notes:
31
+ - Images with different dimensions will be padded with white (value=1.0)
32
+ - A warning is printed when images have different shapes
33
+ - When mode="crop": The function ensures width=518px while maintaining aspect ratio
34
+ and height is center-cropped if larger than 518px
35
+ - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio
36
+ and the smaller dimension is padded to reach a square shape (518x518)
37
+ - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements
38
+ """
39
+ # Check for empty list
40
+ if len(image_path_list) == 0:
41
+ raise ValueError("At least 1 image is required")
42
+
43
+ # Validate mode
44
+ if mode not in ["crop", "pad"]:
45
+ raise ValueError("Mode must be either 'crop' or 'pad'")
46
+
47
+ images = []
48
+ shapes = set()
49
+ to_tensor = TF.ToTensor()
50
+ target_size = 518
51
+
52
+ # First process all images and collect their shapes
53
+ for image_path in image_path_list:
54
+
55
+ # Open image
56
+ img = Image.open(image_path)
57
+
58
+ # If there's an alpha channel, blend onto white background:
59
+ if img.mode == "RGBA":
60
+ # Create white background
61
+ background = Image.new("RGBA", img.size, (255, 255, 255, 255))
62
+ # Alpha composite onto the white background
63
+ img = Image.alpha_composite(background, img)
64
+
65
+ # Now convert to "RGB" (this step assigns white for transparent areas)
66
+ img = img.convert("RGB")
67
+
68
+ width, height = img.size
69
+
70
+ if mode == "pad":
71
+ # Make the largest dimension 518px while maintaining aspect ratio
72
+ if width >= height:
73
+ new_width = target_size
74
+ new_height = round(height * (new_width / width) / 14) * 14 # Make divisible by 14
75
+ else:
76
+ new_height = target_size
77
+ new_width = round(width * (new_height / height) / 14) * 14 # Make divisible by 14
78
+ else: # mode == "crop"
79
+ # Original behavior: set width to 518px
80
+ new_width = target_size
81
+ # Calculate height maintaining aspect ratio, divisible by 14
82
+ new_height = round(height * (new_width / width) / 14) * 14
83
+
84
+ # Resize with new dimensions (width, height)
85
+ img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
86
+ img = to_tensor(img) # Convert to tensor (0, 1)
87
+
88
+ # Center crop height if it's larger than 518 (only in crop mode)
89
+ if mode == "crop" and new_height > target_size:
90
+ start_y = (new_height - target_size) // 2
91
+ img = img[:, start_y: start_y + target_size, :]
92
+
93
+ # For pad mode, pad to make a square of target_size x target_size
94
+ if mode == "pad":
95
+ h_padding = target_size - img.shape[1]
96
+ w_padding = target_size - img.shape[2]
97
+
98
+ if h_padding > 0 or w_padding > 0:
99
+ pad_top = h_padding // 2
100
+ pad_bottom = h_padding - pad_top
101
+ pad_left = w_padding // 2
102
+ pad_right = w_padding - pad_left
103
+
104
+ # Pad with white (value=1.0)
105
+ img = torch.nn.functional.pad(
106
+ img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
107
+ )
108
+
109
+ shapes.add((img.shape[1], img.shape[2]))
110
+ images.append(img)
111
+
112
+ # Check if we have different shapes
113
+ # In theory our model can also work well with different shapes
114
+ if len(shapes) > 1:
115
+ print(f"Warning: Found images with different shapes: {shapes}")
116
+ # Find maximum dimensions
117
+ max_height = max(shape[0] for shape in shapes)
118
+ max_width = max(shape[1] for shape in shapes)
119
+
120
+ # Pad images if necessary
121
+ padded_images = []
122
+ for img in images:
123
+ h_padding = max_height - img.shape[1]
124
+ w_padding = max_width - img.shape[2]
125
+
126
+ if h_padding > 0 or w_padding > 0:
127
+ pad_top = h_padding // 2
128
+ pad_bottom = h_padding - pad_top
129
+ pad_left = w_padding // 2
130
+ pad_right = w_padding - pad_left
131
+
132
+ img = torch.nn.functional.pad(
133
+ img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
134
+ )
135
+ padded_images.append(img)
136
+ images = padded_images
137
+
138
+ images = torch.stack(images) # concatenate images
139
+
140
+ # Ensure correct shape when single image
141
+ if len(image_path_list) == 1:
142
+ # Verify shape is (1, C, H, W)
143
+ if images.dim() == 3:
144
+ images = images.unsqueeze(0)
145
+
146
+ return images
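
A hedged usage sketch of the loader above; the file names are placeholders rather than assets shipped with the repository.

from streamvggt.utils.load_fn import load_and_preprocess_images

image_paths = ["frame_000.png", "frame_001.png", "frame_002.png"]   # placeholder paths
images = load_and_preprocess_images(image_paths, mode="pad")
print(images.shape)   # (3, 3, 518, 518): pad mode always yields 518x518 squares

In crop mode the width is fixed at 518 and the height is center-cropped if necessary, so the batch shape becomes (3, 3, H, 518) with H a multiple of 14 no larger than 518.
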
vggt/utils/__pycache__/pose_enc.cpython-310.pyc → streamvggt/utils/pose_enc.py RENAMED
Binary files a/vggt/utils/__pycache__/pose_enc.cpython-310.pyc and b/streamvggt/utils/pose_enc.py differ
 
streamvggt/utils/rotation.py ADDED
@@ -0,0 +1,138 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d
8
+
9
+ import torch
10
+ import numpy as np
11
+ import torch.nn.functional as F
12
+
13
+
14
+ def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
15
+ """
16
+ Quaternion Order: XYZW or say ijkr, scalar-last
17
+
18
+ Convert rotations given as quaternions to rotation matrices.
19
+ Args:
20
+ quaternions: quaternions with real part last,
21
+ as tensor of shape (..., 4).
22
+
23
+ Returns:
24
+ Rotation matrices as tensor of shape (..., 3, 3).
25
+ """
26
+ i, j, k, r = torch.unbind(quaternions, -1)
27
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
28
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
29
+
30
+ o = torch.stack(
31
+ (
32
+ 1 - two_s * (j * j + k * k),
33
+ two_s * (i * j - k * r),
34
+ two_s * (i * k + j * r),
35
+ two_s * (i * j + k * r),
36
+ 1 - two_s * (i * i + k * k),
37
+ two_s * (j * k - i * r),
38
+ two_s * (i * k - j * r),
39
+ two_s * (j * k + i * r),
40
+ 1 - two_s * (i * i + j * j),
41
+ ),
42
+ -1,
43
+ )
44
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
45
+
46
+
47
+ def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
48
+ """
49
+ Convert rotations given as rotation matrices to quaternions.
50
+
51
+ Args:
52
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
53
+
54
+ Returns:
55
+ quaternions with real part last, as tensor of shape (..., 4).
56
+ Quaternion Order: XYZW or say ijkr, scalar-last
57
+ """
58
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
59
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
60
+
61
+ batch_dim = matrix.shape[:-2]
62
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)
63
+
64
+ q_abs = _sqrt_positive_part(
65
+ torch.stack(
66
+ [
67
+ 1.0 + m00 + m11 + m22,
68
+ 1.0 + m00 - m11 - m22,
69
+ 1.0 - m00 + m11 - m22,
70
+ 1.0 - m00 - m11 + m22,
71
+ ],
72
+ dim=-1,
73
+ )
74
+ )
75
+
76
+ # we produce the desired quaternion multiplied by each of r, i, j, k
77
+ quat_by_rijk = torch.stack(
78
+ [
79
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
80
+ # `int`.
81
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
82
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
83
+ # `int`.
84
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
85
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
86
+ # `int`.
87
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
88
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
89
+ # `int`.
90
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
91
+ ],
92
+ dim=-2,
93
+ )
94
+
95
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
96
+ # the candidate won't be picked.
97
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
98
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
99
+
100
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
101
+ # forall i; we pick the best-conditioned one (with the largest denominator)
102
+ out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,))
103
+
104
+ # Convert from rijk to ijkr
105
+ out = out[..., [1, 2, 3, 0]]
106
+
107
+ out = standardize_quaternion(out)
108
+
109
+ return out
110
+
111
+
112
+ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
113
+ """
114
+ Returns torch.sqrt(torch.max(0, x))
115
+ but with a zero subgradient where x is 0.
116
+ """
117
+ ret = torch.zeros_like(x)
118
+ positive_mask = x > 0
119
+ if torch.is_grad_enabled():
120
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
121
+ else:
122
+ ret = torch.where(positive_mask, torch.sqrt(x), ret)
123
+ return ret
124
+
125
+
126
+ def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
127
+ """
128
+ Convert a unit quaternion to a standard form: one in which the real
129
+ part is non negative.
130
+
131
+ Args:
132
+ quaternions: Quaternions with real part last,
133
+ as tensor of shape (..., 4).
134
+
135
+ Returns:
136
+ Standardized quaternions as tensor of shape (..., 4).
137
+ """
138
+ return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
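
A minimal round-trip check of the scalar-last (XYZW) convention documented above:

import torch
from streamvggt.utils.rotation import quat_to_mat, mat_to_quat

q = torch.tensor([[0.0, 0.0, 0.0, 1.0]])                  # identity rotation, XYZW order
R = quat_to_mat(q)                                        # (1, 3, 3) identity matrix
print(torch.allclose(R, torch.eye(3).expand(1, 3, 3)))    # True
print(mat_to_quat(R))                                     # tensor([[0., 0., 0., 1.]])
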
streamvggt/utils/visual_track.py ADDED
@@ -0,0 +1,239 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import cv2
8
+ import torch
9
+ import numpy as np
10
+ import os
11
+
12
+
13
+ def color_from_xy(x, y, W, H, cmap_name="hsv"):
14
+ """
15
+ Map (x, y) -> color in (R, G, B).
16
+ 1) Normalize x,y to [0,1].
17
+ 2) Combine them into a single scalar c in [0,1].
18
+ 3) Use matplotlib's colormap to convert c -> (R,G,B).
19
+
20
+ You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y).
21
+ """
22
+ import matplotlib.cm
23
+ import matplotlib.colors
24
+
25
+ x_norm = x / max(W - 1, 1)
26
+ y_norm = y / max(H - 1, 1)
27
+ # Simple combination:
28
+ c = (x_norm + y_norm) / 2.0
29
+
30
+ cmap = matplotlib.cm.get_cmap(cmap_name)
31
+ # cmap(c) -> (r,g,b,a) in [0,1]
32
+ rgba = cmap(c)
33
+ r, g, b = rgba[0], rgba[1], rgba[2]
34
+ return (r, g, b) # in [0,1], RGB order
35
+
36
+
37
+ def get_track_colors_by_position(tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv"):
38
+ """
39
+ Given all tracks in one sample (b), compute a (N,3) array of RGB color values
40
+ in [0,255]. The color is determined by the (x,y) position in the first
41
+ visible frame for each track.
42
+
43
+ Args:
44
+ tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame.
45
+ vis_mask_b: (S, N) boolean mask; if None, assume all are visible.
46
+ image_width, image_height: used for normalizing (x, y).
47
+ cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet').
48
+
49
+ Returns:
50
+ track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255].
51
+ """
52
+ S, N, _ = tracks_b.shape
53
+ track_colors = np.zeros((N, 3), dtype=np.uint8)
54
+
55
+ if vis_mask_b is None:
56
+ # treat all as visible
57
+ vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device)
58
+
59
+ for i in range(N):
60
+ # Find first visible frame for track i
61
+ visible_frames = torch.where(vis_mask_b[:, i])[0]
62
+ if len(visible_frames) == 0:
63
+ # track is never visible; just assign black or something
64
+ track_colors[i] = (0, 0, 0)
65
+ continue
66
+
67
+ first_s = int(visible_frames[0].item())
68
+ # use that frame's (x,y)
69
+ x, y = tracks_b[first_s, i].tolist()
70
+
71
+ # map (x,y) -> (R,G,B) in [0,1]
72
+ r, g, b = color_from_xy(x, y, W=image_width, H=image_height, cmap_name=cmap_name)
73
+ # scale to [0,255]
74
+ r, g, b = int(r * 255), int(g * 255), int(b * 255)
75
+ track_colors[i] = (r, g, b)
76
+
77
+ return track_colors
78
+
79
+
80
+ def visualize_tracks_on_images(
81
+ images,
82
+ tracks,
83
+ track_vis_mask=None,
84
+ out_dir="track_visuals_concat_by_xy",
85
+ image_format="CHW", # "CHW" or "HWC"
86
+ normalize_mode="[0,1]",
87
+ cmap_name="hsv", # e.g. "hsv", "rainbow", "jet"
88
+ frames_per_row=4, # New parameter for grid layout
89
+ save_grid=True, # Flag to control whether to save the grid image
90
+ ):
91
+ """
92
+ Visualizes frames in a grid layout with specified frames per row.
93
+ Each track's color is determined by its (x,y) position
94
+ in the first visible frame (or frame 0 if always visible).
95
+ Finally convert the BGR result to RGB before saving.
96
+ Also saves each individual frame as a separate PNG file.
97
+
98
+ Args:
99
+ images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC.
100
+ tracks: torch.Tensor (S, N, 2), last dim = (x, y).
101
+ track_vis_mask: torch.Tensor (S, N) or None.
102
+ out_dir: folder to save visualizations.
103
+ image_format: "CHW" or "HWC".
104
+ normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255
105
+ cmap_name: a matplotlib colormap name for color_from_xy.
106
+ frames_per_row: number of frames to display in each row of the grid.
107
+ save_grid: whether to save all frames in one grid image.
108
+
109
+ Returns:
110
+ None (saves images in out_dir).
111
+ """
112
+
113
+ if len(tracks.shape) == 4:
114
+ tracks = tracks.squeeze(0)
115
+ images = images.squeeze(0)
116
+ if track_vis_mask is not None:
117
+ track_vis_mask = track_vis_mask.squeeze(0)
118
+
119
+ import matplotlib
120
+
121
+ matplotlib.use("Agg") # for non-interactive (optional)
122
+
123
+ os.makedirs(out_dir, exist_ok=True)
124
+
125
+ S = images.shape[0]
126
+ _, N, _ = tracks.shape # (S, N, 2)
127
+
128
+ # Move to CPU
129
+ images = images.cpu().clone()
130
+ tracks = tracks.cpu().clone()
131
+ if track_vis_mask is not None:
132
+ track_vis_mask = track_vis_mask.cpu().clone()
133
+
134
+ # Infer H, W from images shape
135
+ if image_format == "CHW":
136
+ # e.g. images[s].shape = (3, H, W)
137
+ H, W = images.shape[2], images.shape[3]
138
+ else:
139
+ # e.g. images[s].shape = (H, W, 3)
140
+ H, W = images.shape[1], images.shape[2]
141
+
142
+ # Pre-compute the color for each track i based on first visible position
143
+ track_colors_rgb = get_track_colors_by_position(
144
+ tracks, # shape (S, N, 2)
145
+ vis_mask_b=track_vis_mask if track_vis_mask is not None else None,
146
+ image_width=W,
147
+ image_height=H,
148
+ cmap_name=cmap_name,
149
+ )
150
+
151
+ # We'll accumulate each frame's drawn image in a list
152
+ frame_images = []
153
+
154
+ for s in range(S):
155
+ # shape => either (3, H, W) or (H, W, 3)
156
+ img = images[s]
157
+
158
+ # Convert to (H, W, 3)
159
+ if image_format == "CHW":
160
+ img = img.permute(1, 2, 0) # (H, W, 3)
161
+ # else "HWC", do nothing
162
+
163
+ img = img.numpy().astype(np.float32)
164
+
165
+ # Scale to [0,255] if needed
166
+ if normalize_mode == "[0,1]":
167
+ img = np.clip(img, 0, 1) * 255.0
168
+ elif normalize_mode == "[-1,1]":
169
+ img = (img + 1.0) * 0.5 * 255.0
170
+ img = np.clip(img, 0, 255.0)
171
+ # else no normalization
172
+
173
+ # Convert to uint8
174
+ img = img.astype(np.uint8)
175
+
176
+ # For drawing in OpenCV, convert to BGR
177
+ img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
178
+
179
+ # Draw each visible track
180
+ cur_tracks = tracks[s] # shape (N, 2)
181
+ if track_vis_mask is not None:
182
+ valid_indices = torch.where(track_vis_mask[s])[0]
183
+ else:
184
+ valid_indices = range(N)
185
+
186
+ cur_tracks_np = cur_tracks.numpy()
187
+ for i in valid_indices:
188
+ x, y = cur_tracks_np[i]
189
+ pt = (int(round(x)), int(round(y)))
190
+
191
+ # track_colors_rgb[i] is (R,G,B). For OpenCV circle, we need BGR
192
+ R, G, B = track_colors_rgb[i]
193
+ color_bgr = (int(B), int(G), int(R))
194
+ cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1)
195
+
196
+ # Convert back to RGB for consistent final saving:
197
+ img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
198
+
199
+ # Save individual frame
200
+ frame_path = os.path.join(out_dir, f"frame_{s:04d}.png")
201
+ # Convert to BGR for OpenCV imwrite
202
+ frame_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
203
+ cv2.imwrite(frame_path, frame_bgr)
204
+
205
+ frame_images.append(img_rgb)
206
+
207
+ # Only create and save the grid image if save_grid is True
208
+ if save_grid:
209
+ # Calculate grid dimensions
210
+ num_rows = (S + frames_per_row - 1) // frames_per_row # Ceiling division
211
+
212
+ # Create a grid of images
213
+ grid_img = None
214
+ for row in range(num_rows):
215
+ start_idx = row * frames_per_row
216
+ end_idx = min(start_idx + frames_per_row, S)
217
+
218
+ # Concatenate this row horizontally
219
+ row_img = np.concatenate(frame_images[start_idx:end_idx], axis=1)
220
+
221
+ # If this row has fewer than frames_per_row images, pad with black
222
+ if end_idx - start_idx < frames_per_row:
223
+ padding_width = (frames_per_row - (end_idx - start_idx)) * W
224
+ padding = np.zeros((H, padding_width, 3), dtype=np.uint8)
225
+ row_img = np.concatenate([row_img, padding], axis=1)
226
+
227
+ # Add this row to the grid
228
+ if grid_img is None:
229
+ grid_img = row_img
230
+ else:
231
+ grid_img = np.concatenate([grid_img, row_img], axis=0)
232
+
233
+ out_path = os.path.join(out_dir, "tracks_grid.png")
234
+ # Convert back to BGR for OpenCV imwrite
235
+ grid_img_bgr = cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR)
236
+ cv2.imwrite(out_path, grid_img_bgr)
237
+ print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}")
238
+
239
+ print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png")
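
A hedged sketch of how the visualizer above might be driven with random data; demo_track_vis is a placeholder output directory.

import torch
from streamvggt.utils.visual_track import visualize_tracks_on_images

S, N, H, W = 4, 16, 64, 96                                   # made-up sizes
images = torch.rand(S, 3, H, W)                              # CHW frames in [0, 1]
xy = torch.stack([torch.rand(N) * (W - 1), torch.rand(N) * (H - 1)], dim=-1)
tracks = xy.unsqueeze(0).repeat(S, 1, 1)                     # (S, N, 2), static points
visualize_tracks_on_images(images, tracks, out_dir="demo_track_vis")
# Writes frame_0000.png ... frame_0003.png plus tracks_grid.png into demo_track_vis/.
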
vggt/heads/__pycache__/camera_head.cpython-310.pyc DELETED
Binary file (4.27 kB)
 
vggt/heads/__pycache__/camera_head.cpython-311.pyc DELETED
Binary file (6.8 kB)
 
vggt/heads/__pycache__/camera_head.cpython-312.pyc DELETED
Binary file (6.13 kB)
 
vggt/heads/__pycache__/dpt_head.cpython-310.pyc DELETED
Binary file (12.6 kB)
 
vggt/heads/__pycache__/dpt_head.cpython-311.pyc DELETED
Binary file (21.7 kB)
 
vggt/heads/__pycache__/dpt_head.cpython-312.pyc DELETED
Binary file (20.3 kB)
 
vggt/heads/__pycache__/head_act.cpython-311.pyc DELETED
Binary file (4.87 kB)