diff --git "a/modeling_spear.py" "b/modeling_spear.py" new file mode 100644--- /dev/null +++ "b/modeling_spear.py" @@ -0,0 +1,2766 @@ +import collections +import functools +import inspect +import logging +import math +import warnings +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Protocol, + Set, + Tuple, + Type, +) + +import einops +import numpy as np +import roma +import timm +import torch +import torch.distributed.fsdp +import torch.distributed.tensor +import transformers +from huggingface_hub import hf_hub_download + +from .common_spear import ( + Configurable, + DiffusionInput, + FlowInput, + LLMOutput, + RoboticsFlowInput, + RoboticsInput, + RoboticsOutput, + RotationFormat, + VLMOutput, + expand_dims, + is_quaternion, + is_rotmat, + is_rotmat_3x3, + is_rotmat_9, + quaternion_half_cover, + rotmat_as_3x3, + rotmat_as_9, +) +from .configuration_spear import ( + FourierFeaturesConfig, + NoisedControlProjectorConfig, + PaliGemmaVLMConfig, + PiZeroFlowMatchingDecoderBlockConfig, + PiZeroFlowMatchingDecoderConfig, + PiZeroFlowMatchingModuleConfig, + RobotStateProjectorConfig, + RotaryPositionalEncodingConfig, + SPEAR1Config, +) +from .processing_spear import ( + EmptyTokenizer, + PaliGemmaDepthProcessor, + PiZeroFlowMatchingProcessor, +) + + +class ConfigurableModule(torch.nn.Module, Configurable): + def __init__(self, config): + Configurable.__init__(self, config) + torch.nn.Module.__init__(self) + + +class GemmaRMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-06): + super().__init__() + self.eps = eps + self.weight = torch.nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +@torch.no_grad() +def make_module_params_non_trainable( + module: torch.nn.Module, recursive: bool = True, dtype: Optional[torch.dtype] = None +): + """ + NOTE: dtype is applied only to module parameters, not buffers. 
This is different from the + default torch.nn.Module.to(dtype=dtype) behavior, which applies to both parameters and buffers + """ + for param in module.parameters(recurse=recursive): + param.requires_grad = False + if dtype is not None: + param.data = param.to(dtype=dtype) + + +@torch.no_grad() +def make_module_non_trainable(module: torch.nn.Module, dtype: Optional[torch.dtype] = None): + make_module_params_non_trainable(module, dtype=None) + if dtype is not None: + module.to(dtype=dtype) + module.eval() + + +class ResidualConvBlock(torch.nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int | None = None, + hidden_channels: int | None = None, + padding_mode: str = "replicate", + activation: Literal["relu", "leaky_relu", "silu", "elu"] = "relu", + norm: Literal["group_norm", "layer_norm"] = "group_norm", + ): + super().__init__() + if out_channels is None: + out_channels = in_channels + if hidden_channels is None: + hidden_channels = in_channels + if activation == "relu": + activation_cls = functools.partial(torch.nn.ReLU, inplace=True) + elif activation == "leaky_relu": + activation_cls = functools.partial(torch.nn.LeakyReLU, negative_slope=0.2, inplace=True) + elif activation == "silu": + activation_cls = functools.partial(torch.nn.SiLU, inplace=True) + elif activation == "elu": + activation_cls = functools.partial(torch.nn.ELU, inplace=True) + else: + raise ValueError(f"Unsupported activation function: {activation}") + self.layers = torch.nn.Sequential( + torch.nn.GroupNorm(1, in_channels), + activation_cls(), + torch.nn.Conv2d( + in_channels, + hidden_channels, + kernel_size=3, + padding=1, + padding_mode=padding_mode, + ), + torch.nn.GroupNorm(hidden_channels // 32 if norm == "group_norm" else 1, hidden_channels), + activation_cls(), + torch.nn.Conv2d( + hidden_channels, + out_channels, + kernel_size=3, + padding=1, + padding_mode=padding_mode, + ), + ) + self.skip_connection = ( + torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) + if in_channels != out_channels + else torch.nn.Identity() + ) + + def forward(self, x): + skip = self.skip_connection(x) + x = self.layers(x) + x = x + skip + return x + + +def normalized_view_plane_uv( + width: int, + height: int, + aspect_ratio: float | None = None, + dtype: torch.dtype = None, + device: torch.device = None, +) -> torch.Tensor: + """ + UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal) + """ + if aspect_ratio is None: + aspect_ratio = width / height + span_x = aspect_ratio / (1 + aspect_ratio**2) ** 0.5 + span_y = 1 / (1 + aspect_ratio**2) ** 0.5 + u = torch.linspace( + -span_x * (width - 1) / width, + span_x * (width - 1) / width, + width, + dtype=dtype, + device=device, + ) + v = torch.linspace( + -span_y * (height - 1) / height, + span_y * (height - 1) / height, + height, + dtype=dtype, + device=device, + ) + (u, v) = torch.meshgrid(u, v, indexing="xy") + uv = torch.stack([u, v], dim=-1) + return uv + + +class Head(torch.nn.Module): + def __init__( + self, + num_features: int, + dim_in: int, + dim_out: List[int], + dim_proj: int = 512, + dim_upsample: List[int] = [256, 128, 128], + dim_times_res_block_hidden: int = 1, + num_res_blocks: int = 1, + res_block_norm: Literal["group_norm", "layer_norm"] = "group_norm", + last_res_blocks: int = 0, + last_conv_channels: int = 32, + last_conv_size: int = 1, + ): + super().__init__() + self.projects = torch.nn.ModuleList( + [ + torch.nn.Conv2d( + in_channels=dim_in, + 
out_channels=dim_proj, + kernel_size=1, + stride=1, + padding=0, + ) + for _ in range(num_features) + ] + ) + self.upsample_blocks = torch.nn.ModuleList( + [ + torch.nn.Sequential( + self._make_upsampler(in_ch + 2, out_ch), + *( + ResidualConvBlock( + out_ch, + out_ch, + dim_times_res_block_hidden * out_ch, + activation="relu", + norm=res_block_norm, + ) + for _ in range(num_res_blocks) + ), + ) + for (in_ch, out_ch) in zip([dim_proj] + dim_upsample[:-1], dim_upsample, strict=True) + ] + ) + self.output_block = torch.nn.ModuleList( + [ + self._make_output_block( + dim_upsample[-1] + 2, + dim_out_, + dim_times_res_block_hidden, + last_res_blocks, + last_conv_channels, + last_conv_size, + res_block_norm, + ) + for dim_out_ in dim_out + ] + ) + + def _make_upsampler(self, in_channels: int, out_channels: int): + upsampler = torch.nn.Sequential( + torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2), + torch.nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + padding_mode="replicate", + ), + ) + upsampler[0].weight.data[:] = upsampler[0].weight.data[:, :, :1, :1] + return upsampler + + def _make_output_block( + self, + dim_in: int, + dim_out: int, + dim_times_res_block_hidden: int, + last_res_blocks: int, + last_conv_channels: int, + last_conv_size: int, + res_block_norm: Literal["group_norm", "layer_norm"], + ): + return torch.nn.Sequential( + torch.nn.Conv2d( + dim_in, + last_conv_channels, + kernel_size=3, + stride=1, + padding=1, + padding_mode="replicate", + ), + *( + ResidualConvBlock( + last_conv_channels, + last_conv_channels, + dim_times_res_block_hidden * last_conv_channels, + activation="relu", + norm=res_block_norm, + ) + for _ in range(last_res_blocks) + ), + torch.nn.ReLU(inplace=True), + torch.nn.Conv2d( + last_conv_channels, + dim_out, + kernel_size=last_conv_size, + stride=1, + padding=last_conv_size // 2, + padding_mode="replicate", + ), + ) + + def forward(self, hidden_states: List[torch.Tensor], image: torch.Tensor): + (img_h, img_w) = image.shape[-2:] + (patch_h, patch_w) = (img_h // 14, img_w // 14) + x = torch.stack( + [ + proj(feat.permute(0, 2, 1).unflatten(2, (patch_h, patch_w)).contiguous()) + for (proj, feat) in zip(self.projects, hidden_states, strict=True) + ], + dim=1, + ).sum(dim=1) + for _, block in enumerate(self.upsample_blocks): + uv = normalized_view_plane_uv( + width=x.shape[-1], + height=x.shape[-2], + aspect_ratio=img_w / img_h, + dtype=x.dtype, + device=x.device, + ) + uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1) + x = torch.cat([x, uv], dim=1) + for layer in block: + x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False) + x = torch.nn.functional.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False) + uv = normalized_view_plane_uv( + width=x.shape[-1], + height=x.shape[-2], + aspect_ratio=img_w / img_h, + dtype=x.dtype, + device=x.device, + ) + uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1) + x = torch.cat([x, uv], dim=1) + if isinstance(self.output_block, torch.nn.ModuleList): + output = [ + torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False) + for block in self.output_block + ] + else: + output = torch.utils.checkpoint.checkpoint(self.output_block, x, use_reentrant=False) + return output + + +def _is_single_image_size(image_sizes: Dict[str, Dict[str, int]]) -> bool: + return ( + len(image_sizes) == 1 + or len(set(((image_size["height"], image_size["width"]) for image_size in image_sizes.values()))) == 1 + ) 
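+
+
+# Illustrative usage sketch (not part of the original file): `_is_single_image_size` returns True
+# only when every configured camera shares a single resolution, e.g.
+#   _is_single_image_size({"main": {"height": 224, "width": 224},
+#                          "wrist": {"height": 224, "width": 224}})   # -> True
+#   _is_single_image_size({"main": {"height": 224, "width": 224},
+#                          "wrist": {"height": 448, "width": 448}})   # -> False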
+ + +class MoGe(torch.nn.Module): + """ + Implementation of MoGe taken from https://github.com/microsoft/MoGe/blob/main/moge/model/v1.py + Simplified and stripped down such that: + - It doesn't rely on MoGe codebase + - Uses timm for ViT backbone + - Only predicts points and mask + - Currently does NOT infer depth or intrinsics (but this could be added) + - Requires image to be resized to the expected resolution. Note that this resolution must be + in the range of num_tokens_range, (for square images, pixel sizes in the range ~[36*14, 50*14]) + """ + + def __init__( + self, + image_sizes: Dict[str, Dict[str, int]] = {}, + backbone_id: str = "vit_large_patch14_dinov2.lvd142m", + intermediate_layers: int | List[int] = 4, + dim_proj: int = 512, + dim_upsample: List[int] = [256, 128, 64], + dim_times_res_block_hidden: int = 2, + num_res_blocks: int = 2, + remap_output: Literal[False, True, "linear", "sinh", "exp", "sinh_exp"] = "exp", + res_block_norm: Literal["group_norm", "layer_norm"] = "group_norm", + num_tokens_range: Tuple[int, int] = [1275, 2551], + last_res_blocks: int = 0, + last_conv_channels: int = 32, + last_conv_size: int = 1, + mask_threshold: float = 0.5, + output_dtype: torch.dtype = torch.bfloat16, + ): + """ + NOTE: + - All defaults were taken from the checkpoint config for 'Ruicheng/moge-vitl' + - output_dtype - by default was float32, changed to bfloat16 for model training + """ + super().__init__() + self.remap_output = remap_output + self.intermediate_layers = intermediate_layers + self.num_tokens_range = num_tokens_range + self.mask_threshold = mask_threshold + self.output_dtype = output_dtype + self.image_sizes = dict(image_sizes) + self.backbone: timm.models.vision_transformer.VisionTransformer = self._make_vit_backbone(backbone_id) + token_size: int = self.backbone.embed_dim + self.head = Head( + num_features=( + intermediate_layers if isinstance(intermediate_layers, int) else len(intermediate_layers) + ), + dim_in=token_size, + dim_out=[3, 1], + dim_proj=dim_proj, + dim_upsample=dim_upsample, + dim_times_res_block_hidden=dim_times_res_block_hidden, + num_res_blocks=num_res_blocks, + res_block_norm=res_block_norm, + last_res_blocks=last_res_blocks, + last_conv_channels=last_conv_channels, + last_conv_size=last_conv_size, + ) + + def _make_vit_backbone(self, backbone_id: str) -> timm.models.vision_transformer.VisionTransformer: + if _is_single_image_size(self.image_sizes): + kwargs = { + "img_size": ( + self.image_sizes["main"]["height"], + self.image_sizes["main"]["width"], + ), + "dynamic_img_size": False, + } + else: + kwargs = {"img_size": (224, 224), "dynamic_img_size": True} + vit_backbone: timm.models.vision_transformer.VisionTransformer = timm.create_model( + backbone_id, pretrained=False, num_classes=0, **kwargs + ) + vit_backbone.forward = functools.partial( + vit_backbone.forward_intermediates, + indices=4, + return_prefix_tokens=False, + norm=True, + stop_early=True, + output_fmt="NLC", + intermediates_only=True, + ) + return vit_backbone + + def forward(self, image: torch.Tensor) -> Dict[str, torch.Tensor]: + """ + Args: + image: torch.Tensor of shape [B, 3, H, W] containing the preprocessed image, resized to the + size expected by the model. + Returns: + A dictionary containing: + - `points`: torch.Tensor of shape [B, 3, H, W] containing the predicted points. + - `mask`: torch.Tensor of shape [B, 1, H, W] containing the predicted mask. 
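+
+        Example (illustrative sketch, not from the original docstring; assumes a single 518x518
+        "main" camera and a random input batch, so only the shapes are meaningful):
+            >>> moge = MoGe(image_sizes={"main": {"height": 518, "width": 518}})
+            >>> out = moge(torch.randn(1, 3, 518, 518))
+            >>> (out["points"].shape, out["mask"].shape)
+            (torch.Size([1, 3, 518, 518]), torch.Size([1, 1, 518, 518]))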
+ """ + (height, width) = image.shape[-2:] + assert (height, width) in [ + (image_size["height"], image_size["width"]) for image_size in self.image_sizes.values() + ], f"{(height, width)} not in {self.image_sizes}" + features: List[torch.Tensor] = self.backbone(image) + output = self.head(features, image) + (points, mask) = output + with torch.autocast( + device_type=image.device.type, + dtype=torch.float32, + enabled=self.output_dtype == torch.float32, + ): + points = torch.nn.functional.interpolate( + points, + (height, width), + mode="bilinear", + align_corners=False, + antialias=False, + ) + mask = torch.nn.functional.interpolate( + mask, + (height, width), + mode="bilinear", + align_corners=False, + antialias=False, + ) + points = self._remap_points(points, dim=1) + output = {"points": points, "mask": mask} + return output + + def _remap_points(self, points: torch.Tensor, dim: int = 1) -> torch.Tensor: + if self.remap_output == "linear": + pass + elif self.remap_output == "sinh": + points = torch.sinh(points) + elif self.remap_output == "exp": + (xy, z) = points.split([2, 1], dim=dim) + z = torch.exp(z) + points = torch.cat([xy * z, z], dim=dim) + elif self.remap_output == "sinh_exp": + (xy, z) = points.split([2, 1], dim=dim) + points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=dim) + else: + raise ValueError(f"Invalid remap output type: {self.remap_output}") + return points + + +class DepthBackboneConfig(transformers.PretrainedConfig): + def __init__( + self, + hf_hub_repo: str = "", + hf_filename: str = "", + image_sizes: Dict[str, Dict[str, int]] = {}, + **kwargs, + ): + super().__init__(**kwargs) + self.hf_hub_repo = hf_hub_repo + self.hf_filename = hf_filename + self.image_sizes = dict(image_sizes) + + +class PaliGemma3DConfig(transformers.models.paligemma.PaliGemmaConfig): + sub_configs = { + "text_config": transformers.AutoConfig, + "vision_config": transformers.AutoConfig, + "depth_config": DepthBackboneConfig, + } + + def __init__( + self, + depth_config={}, + depth_only: bool = False, + mask_prob: float = 0.0, + projection: str = "", + depth_layers: int = 4, + **kwargs, + ): + super().__init__(**kwargs) + if isinstance(depth_config, dict): + self.depth_config = DepthBackboneConfig(**depth_config) + else: + self.depth_config = depth_config + self.mask_prob = mask_prob + self.depth_only = depth_only + self.projection = projection + self.depth_layers = depth_layers + + @property + def is_single_image_size(self) -> bool: + return ( + len(self.depth_config.image_sizes) == 1 + or len( + set( + ( + (image_size["height"], image_size["width"]) + for image_size in self.depth_config.image_sizes.values() + ) + ) + ) + == 1 + ) + + @property + def camera_names(self) -> List[str]: + return list(self.depth_config.image_sizes.keys()) + + +class NeRFPositionalEmbedding(torch.nn.Module): + def __init__(self, n_frequencies: int, log_scale: bool = True, scale: float = 1.0): + """ + Args: + n_frequencies: Dimension size, same as L parameter in the NeRF paper + scale: Scale factor for the frequencies. To match the formula from the paper + [sin(2^k * pi * x), cos(2^k * pi * x)], set scale to math.pi. In practice, the paper + authors don't multiply by pi and use scale=1.0. 
+ See https://github.com/bmild/nerf/issues/12 + """ + super().__init__() + self.n_frequencies = n_frequencies + if log_scale: + freq = 2 ** torch.arange(self.n_frequencies, dtype=torch.float32) * scale + else: + freq = torch.linspace(1, 2 ** (self.n_frequencies - 1), self.n_frequencies) * scale + self.register_buffer("freq", freq) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + """ + Maps features from dimensionality N to dimensionality N*(2*L + 1), i.e. `2L + 1` output + features are produced for each input feature. + + Embeds x to (x, sin(2^k x), cos(2^k x), ...). NOTE: `x` is also in the output. This is + different from the equation in the paper, but matches the actual author implementation. + See https://github.com/bmild/nerf/issues/12 + + Args: + inputs: torch.Tensor of shape [B, ..., N]; input values to be transformed + Returns: torch.Tensor of shape [B, ..., N*(2*L + 1) = embedding_dim], encoded input values + """ + freq = expand_dims(self.freq, ndim=inputs.ndim + 1, order=[-1, 1]) + spectrum = freq * inputs.unsqueeze(-1) + sin = torch.sin(spectrum) + cos = torch.cos(spectrum) + encoding = torch.stack([sin, cos], dim=-1) + encoding = encoding.view(*inputs.shape[:-1], -1) + encoding = torch.cat([inputs, encoding], dim=-1) + return encoding + + +def make_mlp( + layer_sizes: List[int], + activation: str | Type[torch.nn.Module], + norm: str | Type[torch.nn.Module] | None = torch.nn.LayerNorm, + activate_final: bool = False, + bias: bool = True, +) -> torch.nn.Sequential: + """ + Args: + layer_sizes: List of layer sizes. The first value is the number of input features and the last + value is the number of output features + activation: str or the class of the activation. If str, it should be the exact name of + the activation module under torch.nn, e.g. 'ReLU', 'SiLU', 'GeLU'. Use 'Identity' if + no activation wanted + norm: type of normalization. Same type as `activation`. 
Ex: `torch.nn.LayerNorm`, 'LayerNorm', etc + """ + if len(layer_sizes) == 0: + return torch.nn.Identity() + assert len(layer_sizes) > 1, "Need to provide input and output layer sizes at least" + if isinstance(activation, str): + TorchActivation: Type[torch.nn.Module] = getattr(torch.nn, activation) + else: + TorchActivation: Type[torch.nn.Module] = activation + assert issubclass(TorchActivation, torch.nn.Module), TorchActivation + if isinstance(norm, str): + TorchNorm: Type[torch.nn.Module] = getattr(torch.nn, norm) + elif norm is None: + TorchNorm: Type[torch.nn.Module] = torch.nn.Identity + else: + TorchNorm: Type[torch.nn.Module] = norm + assert issubclass(TorchNorm, torch.nn.Module), TorchNorm + + def make_norm_act(modules: dict[str, torch.nn.Module], empty: bool): + return {} if empty else modules + + module = torch.nn.Sequential( + *[ + torch.nn.Sequential( + collections.OrderedDict( + { + "linear": torch.nn.Linear(in_features, out_features, bias=bias), + **make_norm_act( + {"norm": TorchNorm(out_features), "act": TorchActivation()}, + empty=i == len(layer_sizes) - 2 and not activate_final, + ), + } + ) + ) + for (i, (in_features, out_features)) in enumerate( + zip(layer_sizes[:-1], layer_sizes[1:], strict=True) + ) + ] + ) + return module + + +class PaliGemma3D(transformers.models.paligemma.PaliGemmaForConditionalGeneration): + """ + Transformers-like implementation of PaliGemma with additional depth encoder + """ + + config_class = PaliGemma3DConfig + + def __init__(self, config: PaliGemma3DConfig): + super().__init__(config) + assert self.config.projection in ["features_add"] + if self.config.projection in ["features_add"]: + self.depth_tower = self._make_depth_encoder() + self.depth_projector = torch.nn.Linear( + in_features=self.depth_tower.embed_dim * self.config.depth_layers, + out_features=self.config.text_config.hidden_size, + ) + self.generator: torch.Generator = torch.Generator() + if self.config.depth_only: + make_module_non_trainable(self.vision_tower) + make_module_non_trainable(self.siglip_projector) + self.projector = None + else: + raise ValueError(f"Projection type `{self.config.projection}` not supported!") + + @property + def projectors(self) -> Set[torch.nn.Module]: + modules = set( + ( + module + for module in [ + self.projector, + self.depth_projector, + self.siglip_projector, + ] + if module is not None + ) + ) + if isinstance(self.depth_tower, MoGe): + modules = modules | { + module + for module in self.depth_tower.modules() + if isinstance(module, (ResidualConvBlock, Head)) + } + return modules + + @property + def siglip_projector(self) -> torch.nn.Linear: + return self.multi_modal_projector.linear + + def _make_depth_encoder(self) -> torch.nn.Module: + if self.config.is_single_image_size: + kwargs = { + "img_size": ( + self.config.depth_config.image_sizes["main"]["height"], + self.config.depth_config.image_sizes["main"]["width"], + ), + "dynamic_img_size": False, + } + else: + kwargs = { + "img_size": ( + self.config.depth_config.image_sizes["main"]["height"], + self.config.depth_config.image_sizes["main"]["width"], + ), + "dynamic_img_size": True, + } + model: timm.models.vision_transformer.VisionTransformer = timm.create_model( + "vit_large_patch14_dinov2.lvd142m", + pretrained=False, + num_classes=0, + **kwargs, + ) + model.forward = functools.partial( + model.forward_intermediates, + indices=self.config.depth_layers, + return_prefix_tokens=False, + norm=True, + stop_early=True, + output_fmt="NLC", + intermediates_only=True, + ) + return model + + def 
_load_depth_model_state_dict(self, depth_model: torch.nn.Module): + logging.info( + f"Loading depth model from {self.config.depth_config.hf_hub_repo}/{self.config.depth_config.hf_filename}" + ) + state_dict = torch.load( + hf_hub_download( + repo_id=self.config.depth_config.hf_hub_repo, + filename=self.config.depth_config.hf_filename, + ), + map_location="cpu", + mmap=True, + weights_only=False, + ) + if self.config.projection in ["spatial_add", "spatial_concat"]: + pos_embed_state_dict = {"pos_embed": state_dict["backbone.pos_embed"]} + pos_embed_state_dict = timm.models.vision_transformer.checkpoint_filter_fn( + pos_embed_state_dict, depth_model.backbone + ) + state_dict["backbone.pos_embed"] = pos_embed_state_dict["pos_embed"] + else: + state_dict = timm.models.vision_transformer.checkpoint_filter_fn(state_dict, depth_model) + depth_model.load_state_dict(state_dict) + + def get_image_features(self, pixel_values: dict[str, torch.Tensor]) -> torch.Tensor: + if self.config.projection == "features_add": + images_forward = self._get_image_features_add + elif self.config.projection in ["spatial_add", "spatial_concat"]: + images_forward = self._get_image_features_spatial + else: + raise ValueError(f"Project type `{self.config.projection}` not supported!") + camera_names = self.config.camera_names + if self.config.is_single_image_size: + inputs = { + "siglip": einops.rearrange( + torch.stack( + [pixel_values[f"{camera_name}.siglip"] for camera_name in camera_names], + dim=1, + ), + "B N C H W -> (B N) C H W", + ), + "depth": einops.rearrange( + torch.stack( + [pixel_values[f"{camera_name}.depth"] for camera_name in camera_names], + dim=1, + ), + "B N C H W -> (B N) C H W", + ), + } + image_tokens = images_forward(inputs) + else: + camera_tokens: List[torch.Tensor] = [ + images_forward( + { + "siglip": pixel_values[f"{camera_name}.siglip"], + "depth": pixel_values[f"{camera_name}.depth"], + } + ) + for camera_name in camera_names + ] + image_tokens = torch.cat(camera_tokens, dim=-2) + return image_tokens + + def _get_image_features_add(self, pixel_values: dict[str, torch.Tensor]) -> torch.Tensor: + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values: images of shape `[B * num_images, C, H, W]`. 
+ Keys in the dict correspond to the specific vision encoder to use + Returns: + image_features of shape [B * num_images, num_tokens, token_size = 2048] + """ + siglip_input = pixel_values["siglip"] + depth_input = pixel_values["depth"] + siglip_output: ViTOutput = self.vision_tower(siglip_input) + siglip_features = siglip_output.last_hidden_state + depth_output: list[torch.Tensor] = self.depth_tower(depth_input) + depth_features = torch.cat(depth_output, dim=-1) if len(depth_output) > 1 else depth_output[0] + siglip_features = self.siglip_projector(siglip_features) + depth_features = self.depth_projector(depth_features) + if self.config.depth_only: + image_features = depth_features + elif self.training and torch.bernoulli(torch.tensor(self.config.mask_prob), generator=self.generator): + num_tokens = depth_features.shape[1] + device = depth_features.device + ones = torch.ones((depth_features.shape[0], num_tokens), device=device) + indices = torch.multinomial(ones / num_tokens, num_samples=num_tokens // 2) + mask = ones.to(dtype=torch.bool).scatter_(dim=-1, index=indices, value=0) + mask = mask.unsqueeze(-1) + image_features = siglip_features * mask + depth_features * ~mask + else: + image_features = (siglip_features + depth_features) / 2 + image_features = image_features / self.config.text_config.hidden_size**0.5 + return image_features + + def _get_image_features_spatial(self, pixel_values: dict[str, torch.Tensor]) -> torch.Tensor: + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values: images of shape `[B * num_images, C, H, W]`. + Keys in the dict correspond to the specific vision encoder to use + Returns: + image_features of shape [B * num_images, num_tokens, token_size = 2048] + """ + siglip_input = pixel_values["siglip"] + depth_input = pixel_values["depth"] + siglip_output: ViTOutput = self.vision_tower(siglip_input) + siglip_features = siglip_output.last_hidden_state + depth_output: dict[str, torch.Tensor] = self.depth_tower(depth_input) + points = depth_output["points"] + mask = depth_output["mask"] + mask_binary = mask > self.depth_tower.mask_threshold + points = torch.where(mask_binary, points, 0) + points_embed: torch.Tensor = self.depth_projector(points) + if self.config.projection == "spatial_concat": + features = torch.cat([siglip_features, points_embed], dim=-1) + image_features = self.projector(features) + else: + features = siglip_features + points_embed + image_features = self.siglip_projector(features) + image_features = image_features / self.config.text_config.hidden_size**0.5 + return image_features + + @classmethod + def from_pretrained(cls, *args, **kwargs) -> "PaliGemma3D": + model = super().from_pretrained(*args, **kwargs) + model._load_depth_model_state_dict(model.depth_tower) + return model + + +class RobotStateProjector(ConfigurableModule): + """Pack robot state and project to a single token per timestep""" + + def __init__(self, config: RobotStateProjectorConfig): + super().__init__(config) + if self.config.fourier: + raise NotImplementedError("Fourier robot state projector is not implemented yet") + + self.robot_state_tokens_proj = make_mlp( + layer_sizes=self.config.layers, + activation=self.config.activation, + norm=torch.nn.LayerNorm, + ) + + def forward(self, inputs: RoboticsInput) -> Optional[torch.Tensor]: + """ + Returns: + torch.Tensor of shape [B, num_past_steps, token_size] or None (if mode == 'none') + """ + if self.config.mode == "ee_pose": + robot_state = 
torch.cat([inputs.ee_pose_translation, inputs.ee_pose_rotation], dim=-1) + elif self.config.mode == "ee_pose_gripper": + robot_state = torch.cat( + [inputs.ee_pose_translation, inputs.ee_pose_rotation, inputs.gripper], + dim=-1, + ) + elif self.config.mode == "ee_pose_joints": + robot_state = torch.cat( + [inputs.ee_pose_translation, inputs.ee_pose_rotation, inputs.joints], + dim=-1, + ) + elif self.config.mode == "joints": + robot_state = inputs.joints + elif self.config.mode == "all": + robot_state = torch.cat( + [ + inputs.ee_pose_translation, + inputs.ee_pose_rotation, + inputs.gripper, + inputs.joints, + ], + dim=-1, + ) + elif self.config.mode == "none": + robot_state = torch.tensor([], device=inputs.ee_pose_translation.device).view( + inputs.ee_pose_translation.shape[0], + 0, + self.config.layers[0] if len(self.config.layers) > 0 else 0, + ) + else: + raise NotImplementedError(f"Unknown image tokens mode {self.config.mode}") + output = self.robot_state_tokens_proj(robot_state) + return output + + +class FourierFeatures(ConfigurableModule): + def __init__(self, config: FourierFeaturesConfig): + super().__init__(config) + if self.config.learnable_features: + self.linear = torch.nn.Linear( + in_features=1, out_features=self.config.num_features // 2, bias=False + ) + else: + half_dim = self.config.num_features // 2 + freqs = torch.log(torch.tensor(self.config.max_period)) / (half_dim - 1) + freqs = torch.exp(-freqs * torch.arange(half_dim)) + self.register_buffer("freqs", freqs) + self.layers: torch.nn.Sequential = make_mlp( + self.config.layers, + activation=self.config.activation, + norm=self.config.norm, + activate_final=False, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Compute Fourier features and project them via MLP + Args: + x: Input tensor of shape [..., 1] + Returns: + torch.Tensor: Fourier features of shape [..., num_features] or [..., layers[-1]] + """ + assert x.shape[-1] == 1 and x.ndim > 1, x.shape + if self.config.learnable_features: + frequencies = 2 * math.pi * self.linear(x) + else: + frequencies = x * expand_dims(self.freqs, x.ndim, [-1, 1]) + output = torch.cat([torch.cos(frequencies), torch.sin(frequencies)], dim=-1) + output = self.layers(output) + return output + + +class NoisedControlProjector(ConfigurableModule): + """Pack noised control (translation, rotation, gripper) and project to a single token per timestep""" + + def __init__(self, config: NoisedControlProjectorConfig): + super().__init__(config) + self.input_projector = torch.nn.Linear( + in_features=self.config.layers[0], + out_features=self.config.layers[1] // 2, + bias=False, + ) + self.time_embed = FourierFeatures(self.config.time_embed) + self.layers = make_mlp( + self.config.layers[1:], + activation=self.config.activation, + norm=self.config.norm, + activate_final=False, + bias=False, + ) + + def forward(self, inputs: FlowInput | DiffusionInput) -> Optional[torch.Tensor]: + """ + Returns: + torch.Tensor of shape [B, num_control_timesteps, token_size] + """ + noised_controls = torch.cat([inputs.translation_t, inputs.rotation_t, inputs.gripper_t], dim=-1) + noised_controls = self.input_projector(noised_controls) + timestep = self.time_embed(inputs.timestep) + timestep = timestep.expand(-1, noised_controls.shape[1], -1) + features = torch.cat([timestep, noised_controls], dim=-1) + output = self.layers(features) + return output + + +def unmask_unattended(attn_mask: torch.Tensor, mask_value: Optional[float] = None) -> torch.Tensor: + """ + Copy-pased from 
`transformers.modeling_attn_mask_utils.AttentionMaskConverter._unmask_unattended` + + Attend to all tokens in fully-masked rows. This is required by F.scaled_dot_product_attention + memory-efficient attention path. Otherwise, results are NaN + Details: https://github.com/pytorch/pytorch/issues/110213 + + Args: + attn_mask: [B, 1 | num_heads, query_seq_len, kv_seq_len] or [B, query_seq_len, kv_seq_len], float dtype + mask_value: The value inside `attn_mask` that corresponds to masked elements + Returns: + + For example, if `attn_mask` is (e.g. here left-padding case) + ``` + [ + [[ + [0, 0, 0], + [0, 0, 0], + [0, 0, 1] + ]], + [[ + [1, 0, 0], + [1, 1, 0], + [1, 1, 1] + ]], + [[ + [0, 0, 0], + [0, 1, 0], + [0, 1, 1] + ]] + ] + ``` + then the modified `attn_mask` will be + ``` + [ + [[ + [1, 1, 1], <-- modified + [1, 1, 1], <-- modified + [0, 0, 1] + ]], + [[ + [1, 0, 0], + [1, 1, 0], + [1, 1, 1] + ]], + [[ + [1, 1, 1], <-- modified + [0, 1, 0], + [0, 1, 1] + ]] + ] + ``` + """ + assert attn_mask.dtype.is_floating_point, attn_mask.dtype + if mask_value is None: + mask_value = torch.finfo(attn_mask.dtype).min + return attn_mask * ~torch.all(attn_mask == mask_value, dim=-1, keepdim=True) + + +@torch.no_grad() +def attn_mask_to_float(attn_mask: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """ + Convert a 4D mask of type bool to `dtype`. If the attn_mask isn't 4D or isn't bool, raise error + """ + assert attn_mask.ndim == 4, attn_mask.shape + assert attn_mask.dtype == torch.bool, attn_mask.dtype + if dtype is None: + dtype = torch.get_autocast_dtype(attn_mask.device.type) + mask_value = torch.finfo(dtype).min + attn_mask = torch.zeros(attn_mask.shape, dtype=dtype, device=attn_mask.device).masked_fill( + ~attn_mask, mask_value + ) + attn_mask = unmask_unattended(attn_mask, mask_value) + return attn_mask + + +@torch.no_grad() +def make_4d_float_attn_mask( + attn_mask: Optional[torch.Tensor], + query_seq_length: int, + kv_seq_length: int, + dtype: torch.dtype, + device: torch.device, + batch_size: int, +) -> torch.Tensor: + """ + Creates a 4D mask of shape [B | 1, 1, query_length, kv_seq_length] from a 2D mask of shape [B, kv_seq_length]. + If the input `attn_mask` is already 4D: if dtype=torch.bool, convert to dtype, else do nothing + If the input is None, output is a full bi-directional attn_mask + + Args: + attn_mask: A 2D attention mask of shape [B, kv_seq_length] or [B, 1, query_length, kv_seq_length] + and dtype bool. False values indicate masked out positions + query_seq_length: The query sequence length (L) + kv_seq_length: The key-value sequence length (S). When `transformers.StaticCache` is used, this should + equal the cache size to account for zero-padding the part of the cache that is not yet filled. + dtype: Output dtype + device: Output device + batch_size: Batch size + Returns: + torch.Tensor of shape [B | 1, 1, query_length, kv_seq_length] (i.e. [B | 1, 1, L, S]). 
+ Contains zero at unmasked positions and `torch.finfo(dtype).min` at masked positions + """ + if attn_mask is not None and attn_mask.ndim == 4: + if attn_mask.dtype == torch.bool: + attn_mask = attn_mask_to_float(attn_mask, dtype=dtype) + elif attn_mask.dtype != dtype: + raise TypeError(f"Expected attn_mask.dtype={dtype}, but got {attn_mask.dtype}") + return attn_mask + mask_value = torch.finfo(dtype).min + output_mask = torch.zeros([batch_size, 1, query_seq_length, kv_seq_length], dtype=dtype, device=device) + if attn_mask is not None: + assert attn_mask.dtype == torch.bool, f"Unsupported dtype {attn_mask.dtype}" + mask_length = attn_mask.shape[-1] + if mask_length != kv_seq_length: + raise NotImplementedError(f"{mask_length} != {kv_seq_length} not properly supported yet") + inverted_mask = ~attn_mask.view(batch_size, 1, 1, mask_length) + output_mask[..., :mask_length] = output_mask[..., :mask_length].masked_fill(inverted_mask, mask_value) + return output_mask + + +class VLMInput(Protocol): + input_ids: torch.Tensor + attn_mask: torch.Tensor + images: Dict[str, torch.Tensor] + multimodal_indices: torch.Tensor + unimodal_indices: torch.Tensor + + @property + def inputs_embeds(self) -> Optional[torch.Tensor]: + return None + + @property + def past_key_values(self) -> Optional[List[torch.Tensor]]: + return None + + +def zero_out_param_pretrained_grad( + param: torch.nn.Parameter | torch.distributed.tensor.DTensor, + module: torch.nn.Module, +) -> None: + """Zero out the gradients of pretrained embeddings""" + module.mask = module.mask.to(param.device) + if isinstance(param, torch.distributed.tensor.DTensor) and not isinstance( + module.mask, torch.distributed.tensor.DTensor + ): + module.mask = torch.distributed.tensor.distribute_tensor( + module.mask, device_mesh=param.device_mesh, placements=param.placements + ) + mask = module.mask + if type(param) is torch.distributed.tensor.DTensor and type(mask) is torch.distributed.tensor.DTensor: + assert param.grad.shape == mask.shape, f"{param.grad.shape} != {mask.shape}" + param.grad._local_tensor = torch.where(mask._local_tensor, param.grad._local_tensor, 0) + elif type(param) in (torch.Tensor, torch.nn.Parameter) and type(mask) is torch.Tensor: + assert param.grad.shape == mask.shape, f"{param.grad.shape} != {mask.shape}" + param.grad = torch.where(mask, param.grad, 0) + elif type(param) in (torch.Tensor, torch.nn.Parameter) and type(mask) is torch.distributed.tensor.DTensor: + mask = mask.full_tensor() + assert param.grad.shape == mask.shape, f"{param.grad.shape} != {mask.shape}" + param.grad = torch.where(mask, param.grad, 0) + else: + raise ValueError(f"Unsupported parameter type: {type(param)} and mask type: {type(mask)}") + + +class PaliGemmaVLM(ConfigurableModule): + """Wraps PaliGemma to make compatible with VLM API""" + + def __init__(self, config: PaliGemmaVLMConfig): + super().__init__(config) + if self.config.with_depth: + config = PaliGemma3DConfig.from_pretrained( + self.config.model_id, **self.config.paligemma_3d_config_dict + ) + self.model = PaliGemma3D.from_pretrained( + self.config.model_id, + config=config, + attn_implementation=self.config.attn_implementation, + ) + else: + self.model = transformers.AutoModelForVision2Seq.from_pretrained( + self.config.model_id, + attn_implementation=self.config.attn_implementation, + ) + hf_processor = transformers.AutoProcessor.from_pretrained(self.config.model_id) + self.processor = PaliGemmaDepthProcessor( + config=self.config.processor_config, + hf_processor=hf_processor, + 
depth_tokens=self.config.depth_tokens, + ) + self._resize_siglip_image_input() + self._maybe_override_get_image_features() + if self.config.depth_tokens > 0: + self._resize_llm_token_embeddings(hf_processor.tokenizer) + if not self.config.lm_head: + self.model.language_model.lm_head = torch.nn.Identity() + self.model.train(True) + for decoder in self.model.language_model.model.layers: + decoder.self_attn.is_causal = False + + def forward( + self, + inputs: VLMInput, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + **kwargs, + ) -> VLMOutput: + del kwargs + self._maybe_register_zero_out_grad_hooks() + cache = transformers.DynamicCache() + if inputs.attn_mask.ndim == 4: + attn_mask = attn_mask_to_float(inputs.attn_mask) + else: + attn_mask = inputs.attn_mask + images = { + encoder_camera_name: camera_images.view(-1, *camera_images.shape[2:]) + for (encoder_camera_name, camera_images) in inputs.images.items() + } + llm_output: transformers.models.paligemma.modeling_paligemma.PaliGemmaCausalLMOutputWithPast = ( + self.model( + input_ids=inputs.input_ids, + pixel_values=images, + attention_mask=attn_mask, + use_cache=use_cache, + past_key_values=cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + ) + indices = torch.arange(inputs.input_ids.shape[1], dtype=torch.int64, device=inputs.input_ids.device) + image_indices = indices[inputs.input_ids[0] == self.processor.hf_processor.image_token_id] + text_indices = indices[inputs.input_ids[0] != self.processor.hf_processor.image_token_id] + output = VLMOutput( + llm_output=LLMOutput.from_transformers( + input_ids=inputs.input_ids, + llm_output=llm_output, + text_indices=text_indices, + image_indices=image_indices, + ), + vit_tokens=llm_output.image_hidden_states, + attn_mask=inputs.attn_mask, + ) + return output + + @property + def fsdp_wrap_modules(self) -> Set[torch.nn.Module]: + transformer_modules = { + module + for module in self.modules() + if isinstance( + module, + ( + transformers.models.siglip.modeling_siglip.SiglipEncoderLayer, + transformers.models.siglip.modeling_siglip.SiglipVisionTransformer, + timm.models.vision_transformer.Block, + timm.models.vision_transformer.VisionTransformer, + transformers.models.gemma.modeling_gemma.GemmaDecoderLayer, + ), + ) + or module + in ( + self.model.language_model.model.embed_tokens, + self.model.language_model.model.norm, + ) + } + if self.config.with_depth: + projectors = self.model.projectors + else: + projectors = {self.model.multi_modal_projector} + return projectors | transformer_modules + + def _resize_siglip_image_input(self) -> None: + """ + Enables resizing SigLIP positional embeddings to a new image size. 
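+
+        As a worked example (a sketch; the real sizes come from `processor_config.image_sizes`),
+        moving from the pretrained 224x224 SigLIP input to a 364x364 crop resamples the absolute
+        position grid from 16x16 to 26x26 patches (patch size 14) via bicubic interpolation; the
+        patched `interpolate_pos_encoding` then covers any further resolution change at run time.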
+ """ + num_image_tokens: int = self.config.processor_config.num_image_tokens["main"] + image_size: Dict[str, int] = self.config.processor_config.image_sizes["main"].as_json() + siglip_embeddings: transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings = ( + self.model.vision_tower.vision_model.embeddings + ) + embedding_weight: torch.Tensor = timm.layers.pos_embed.resample_abs_pos_embed( + posemb=siglip_embeddings.position_embedding.weight.unsqueeze(0), + new_size=(image_size["height"] // 14, image_size["width"] // 14), + old_size=( + siglip_embeddings.image_size // 14, + siglip_embeddings.image_size // 14, + ), + num_prefix_tokens=0, + interpolation="bicubic", + antialias=True, + verbose=False, + ).squeeze(0) + with torch.no_grad(): + siglip_embeddings.position_embedding.weight.data = embedding_weight + siglip_embeddings.num_patches = siglip_embeddings.num_positions = num_image_tokens + siglip_embeddings.image_size = dict(image_size) + siglip_embeddings.register_buffer( + "position_ids", + torch.arange(siglip_embeddings.num_positions).expand((1, -1)), + persistent=False, + ) + + def interpolate_pos_encoding(embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + del embeddings + new_size = (height // 14, width // 14) + old_size = (image_size["height"] // 14, image_size["width"] // 14) + if old_size == new_size: + patch_pos_embedding = siglip_embeddings.position_embedding.weight + else: + patch_pos_embedding: torch.Tensor = timm.layers.pos_embed.resample_abs_pos_embed( + posemb=siglip_embeddings.position_embedding.weight.unsqueeze(0), + new_size=new_size, + old_size=( + image_size["height"] // 14, + image_size["width"] // 14, + ), + num_prefix_tokens=0, + interpolation="bicubic", + antialias=True, + verbose=False, + ).squeeze(0) + return patch_pos_embedding + + siglip_embeddings.interpolate_pos_encoding = interpolate_pos_encoding + if not self.processor.config.is_single_image_size: + self.model.vision_tower.forward = functools.partial( + self.model.vision_tower.forward, interpolate_pos_encoding=True + ) + + def _maybe_override_get_image_features(self) -> None: + """ + Override PaliGemmaForConditionalGeneration.get_image_features() from transformers + such that it can handle multiple cameras with different resolutions. 
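+
+        For example (a sketch; camera names and resolutions are taken from the processor config), a
+        224x224 "main" camera yields 256 SigLIP tokens and a 112x112 "wrist" camera yields 64, and
+        the per-camera sequences are embedded separately and concatenated along the token dimension
+        into a single [B, 320, hidden_size] image-token sequence.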
+ """ + if self.config.with_depth: + return + images_forward = self.model.get_image_features + camera_names: List[str] = self.processor.config.camera_names + + def get_image_features(pixel_values: torch.Tensor) -> torch.Tensor: + if self.processor.config.is_single_image_size: + inputs = einops.rearrange( + torch.stack( + [pixel_values[f"{camera_name}.siglip"] for camera_name in camera_names], + dim=1, + ), + "B N C H W -> (B N) C H W", + ) + image_tokens = images_forward(inputs) + else: + camera_tokens: List[torch.Tensor] = [ + images_forward(pixel_values[f"{camera_name}.siglip"]) for camera_name in camera_names + ] + image_tokens = torch.cat(camera_tokens, dim=-2) + return image_tokens + + self.model.get_image_features = get_image_features + + def _resize_llm_token_embeddings(self, tokenizer: transformers.PreTrainedTokenizer) -> None: + assert self.config.depth_tokens > 0, self.config.depth_tokens + tokenizer.add_tokens([f"" for i in range(self.config.depth_tokens)]) + total_num_tokens = len(tokenizer) + vocab_size = tokenizer.vocab_size + llm = self.model.language_model + (_, hidden_size) = llm.lm_head.weight.shape + self.model.resize_token_embeddings( + total_num_tokens, + pad_to_multiple_of=64, + mean_resizing=self.config.mean_resizing, + ) + if self.config.train_only_depth_tokens: + assert len(self.model.language_model._tied_weights_keys) > 0 + (weight_size, hidden_size) = llm.lm_head.weight.shape + mask = torch.cat( + [ + torch.zeros([vocab_size, hidden_size], dtype=torch.bool), + torch.ones([weight_size - vocab_size, hidden_size], dtype=torch.bool), + ], + dim=0, + ) + self.mask = mask + self.embed_handle = None + self.lm_head_handle = None + self._maybe_register_zero_out_grad_hooks() + + def _maybe_register_zero_out_grad_hooks(self) -> None: + """ + Register hooks to zero out the gradients of pretrained embeddings and LM head. + Skips registering the hooks if they already exist. This runs at every step as wrapping + in FSDP removes any hooks that were registered on the *parameters* of the original module + and the only way to run this reliably is to check if the hooks exist. 
+ """ + if not self.config.train_only_depth_tokens: + return + llm = self.model.language_model + if ( + self.embed_handle is None + or llm.model.embed_tokens.weight._post_accumulate_grad_hooks is None + or self.embed_handle.id not in llm.model.embed_tokens.weight._post_accumulate_grad_hooks + ): + self.embed_handle = llm.model.embed_tokens.weight.register_post_accumulate_grad_hook( + functools.partial(zero_out_param_pretrained_grad, module=self) + ) + if llm.model.embed_tokens.weight is not llm.lm_head.weight and ( + self.lm_head_handle is None + or llm.lm_head.weight._post_accumulate_grad_hooks is None + or self.lm_head_handle.id not in llm.lm_head.weight._post_accumulate_grad_hooks + ): + self.lm_head_handle = llm.lm_head.weight.register_post_accumulate_grad_hook( + functools.partial(zero_out_param_pretrained_grad, module=self) + ) + + +def make_position_indices( + position_indices: Optional[torch.Tensor], + seq_length: int, + device: torch.device, + max_seq_length: Optional[int], +) -> torch.Tensor: + if position_indices is not None: + position_indices = position_indices.to(dtype=torch.int64) + else: + position_indices = torch.arange(seq_length, dtype=torch.int64, device=device).view(1, -1) + if not torch.max(position_indices) < max_seq_length: + raise IndexError( + f"position_indices={position_indices} contains index out of bounds of num_embeddings={max_seq_length}" + ) + return position_indices + + +class RotaryPositionalEncoding(ConfigurableModule): + """ + Rotary Positional Embeddings (RoPE) from https://arxiv.org/abs/2104.09864 + Reference implementations: + - https://github.com/meta-llama/llama/blob/main/llama/model.py#L80 + - transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding + - transformers.models.llama.modeling_llama.LlamaRotaryEmbedding + + If cached=True, we cache the embeddings for each position up to `num_embeddings` + """ + + def __init__(self, config: RotaryPositionalEncodingConfig): + super().__init__(config) + inv_freq = 1.0 / self.config.base ** ( + torch.arange(0, self.config.embedding_dim, 2, dtype=torch.float32) / self.config.embedding_dim + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._build_cache() + + def _build_cache(self) -> None: + if not self.config.cached: + return + position_indices = torch.arange(self.config.num_embeddings, dtype=torch.float32) + indices_inv_freq = torch.einsum("i, j -> ij", position_indices, self.inv_freq) + sin = torch.sin(indices_inv_freq) + cos = torch.cos(indices_inv_freq) + self.register_buffer("sin_cache", sin, persistent=False) + self.register_buffer("cos_cache", cos, persistent=False) + + def forward( + self, + tokens: torch.Tensor, + position_indices: Optional[torch.Tensor] = None, + apply: bool = True, + ) -> torch.Tensor: + """ + Args: + tokens: torch.Tensor of shape [B, ..., S, head_dim], where `...` might be any number of dims + position_indices: torch.Tensor of shape [B | 1, S]. 
The indices of tokens within the sequence + apply: If True, apply the positional embedding on tokens and return the result + Returns: + torch.Tensor of the same shape as `tokens` with positional embedding applied on tokens + """ + assert apply, f"{self.__class__} does not support applying embeddings externally" + position_indices = make_position_indices( + position_indices, + seq_length=tokens.shape[-2], + device=tokens.device, + max_seq_length=self.config.num_embeddings, + ) + if self.config.cached: + sin = self.sin_cache[position_indices] + cos = self.cos_cache[position_indices] + sin = torch.cat([sin, sin], dim=-1) + cos = torch.cat([cos, cos], dim=-1) + else: + inv_freq = self.inv_freq.view(1, -1, 1).to(dtype=torch.float32) + position_indices = position_indices.to(dtype=torch.float32).unsqueeze(1) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="In CPU autocast, but the target dtype is not supported. Disabling autocast.", + ) + with torch.autocast(device_type=tokens.device.type, dtype=torch.float32): + freqs = (inv_freq @ position_indices).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + (sin, cos) = (torch.sin(emb), torch.cos(emb)) + (sin, cos) = (sin.to(dtype=tokens.dtype), cos.to(dtype=tokens.dtype)) + sin = expand_dims(sin, tokens.ndim, order=[1, -1, 1, 1]) + cos = expand_dims(cos, tokens.ndim, order=[1, -1, 1, 1]) + tokens = tokens * cos + self._rotate_invert_half(tokens) * sin + return tokens + + @staticmethod + def _rotate_invert_half(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +EAGER_ATTN = "eager" + +SDPA_ATTN = "sdpa" + +FLASH_ATTN = "flash_attention_2" + + +def is_full_attn(attn_mask: Optional[torch.Tensor]) -> bool: + """ + Return True if attn_mask doesn't contain any masked out positions, False otherwise + """ + if attn_mask is None: + return True + if attn_mask.dtype == torch.bool: + return torch.all(attn_mask == 1).item() + if attn_mask.dtype.is_floating_point: + return torch.all(attn_mask == 0).item() + raise TypeError(f"Unrecognized dtype {attn_mask.dtype}") + + +@torch.no_grad() +def make_attn_mask_causal(attn_mask: torch.Tensor, cache_position: torch.Tensor) -> torch.Tensor: + """ + Args: + attn_mask: 4D tensor of shape [B | 1, 1, query_seq_len, kv_seq_len] (i.e. [B | 1, 1, L, S]) of float + dtype (NOT bool!). Masked positions contain the value `torch.finfo(dtype).min` + cache_position: torch.Tensor of type torch.int64 and shape [query_seq_len]. Contained values + are index positions of the query tokens in the sequence. During training, this would usually + be torch.arange(query_seq_len), but during generate, this would usually be a tensor sequence + with 1 element indicating the position of the token currently being generated + Returns: + torch.Tensor of the same shape as attn_mask. 
Contains zero at unmasked positions and + `torch.finfo(dtype).min` at masked positions + """ + if attn_mask.dtype.is_floating_point: + mask_value = torch.finfo(attn_mask.dtype).min + elif attn_mask.dtype == torch.bool: + mask_value = 0 + else: + raise TypeError(f"Unsupported mask type {attn_mask.dtype}") + (_, _, query_seq_length, kv_seq_length) = attn_mask.shape + causal_mask = torch.ones(attn_mask.shape, dtype=torch.bool, device=attn_mask.device) + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask = causal_mask * ( + torch.arange(kv_seq_length, device=cache_position.device).view(1, -1) > cache_position.view(-1, 1) + ).view(*[1] * (causal_mask.ndim - 2), query_seq_length, kv_seq_length) + causal_attn_mask = attn_mask.masked_fill_(causal_mask, mask_value) + return causal_attn_mask + + +def update_attn_mask( + attn_mask: Optional[torch.Tensor], + attn_implementation: str, + query_seq_length: int, + kv_seq_length: int, + cache_position: Optional[torch.Tensor], + cache: Optional[transformers.Cache], + batch_size: int, + causal: bool, + dtype: torch.dtype, + device: torch.device, + output_attentions: bool = False, +) -> Optional[torch.Tensor]: + """ + Update attn_mask such that it's compatible with the attention implementation. + Meant to be used with barrel.components.nn.layers.attention.MultiheadAttention and its derivatives + + Args: + attn_mask: dtype torch.bool, torch.float32, torch.float16 or torch.bfloat16 and shape one of: + - [B, kv_seq_length] (i.e. [B, S]) + - [B, 1, query_seq_length, kv_seq_length] (i.e. [B, 1, L, S]) + - [1, 1, query_seq_length, kv_seq_length] (i.e. [L, S]) + If bool, False values indicate masked positions. + If float, must contain only 0.0 and torch.finfo(dtype).min + If attn_mask is None, full-bidirectional attention is assumed. The output might be None or + a tensor. Refer to the return value documentation + attn_implementation: One of [FLASH_ATTN, SDPA_ATTN, EAGER_ATTN] + query_seq_length: The query sequence length (L) + kv_seq_length: The key-value sequence length (S) + cache_position: dtype torch.int64, shape [query_seq_len]. Used only when causal=True. + Contained values are index positions of the query tokens in the sequence. During training, + this would usually be torch.arange(query_seq_len), but during generate, this would usually be + a tensor sequence with 1 element indicating the position of the token currently being generated. + If None, default `cache_positions` are autocomputed from `query_seq_length` and cache size + cache: Optional cache. Usually not None when running generate at inference. + batch_size: Batch size of the generated attention mask + causal: If True, make the attn_mask causal -> all non-causal positions are masked out, regardless + of their attn_mask values. When using flash attention or SDPA and `causal == False`, make sure + to pass `causal` to the attention operation, in case this function delegates causal masking + dtype: dtype of the output attention mask. Must be the dtype of the attn computation + device: device of the output attention mask + output_attentions: If True, the attention operation is required to output attention weights + Returns: + - `None` in either of these cases: + - `attn_mask` doesn't contain any masked out positions and causal=False + - `attn_implementation in [FLASH_ATTN, SDPA_ATTN]` and `attn_mask` doesn't contain any + masked out positions. If causal=True, we instead rely on the causal argument to + flash attention or `torch.nn.functional.scaled_dot_product_attention`. 
This happens + only if the cache is empty and cache_position is None + - `attn_mask` if `attn_implementation == FLASH_ATTN` and `attn_mask` can't be ignored TODO(FLASH) + - torch.Tensor of shape [B, 1, query_length, kv_seq_length] (i.e. [B, 1, L, S]) and type `dtype`. + Contains zero at unmasked positions and `torch.finfo(dtype).min` at masked positions. + """ + assert attn_implementation in [FLASH_ATTN, SDPA_ATTN, EAGER_ATTN] + assert dtype.is_floating_point, dtype + if torch.jit.is_tracing() or torch.jit.is_scripting() or torch.compiler.is_compiling(): + raise NotImplementedError("Complete correctness not confirmed yet") + if isinstance(cache, transformers.StaticCache): + if attn_mask is not None and attn_mask.shape[-1] != cache.get_max_cache_shape(): + raise NotImplementedError("Complete correctness not confirmed yet") + full_attn = is_full_attn(attn_mask) + past_seen_tokens = cache.get_seq_length() if cache is not None else 0 + if full_attn and not causal: + return None + if ( + full_attn + and causal + and attn_implementation in [SDPA_ATTN, FLASH_ATTN] + and past_seen_tokens == 0 + and cache_position is None + ): + return None + past_seen_tokens = cache.get_seq_length() if cache is not None else 0 + static_cache = isinstance(cache, transformers.StaticCache) + if static_cache and kv_seq_length < cache.get_max_cache_shape(): + kv_seq_length = cache.get_max_cache_shape() + elif attn_mask is not None: + assert kv_seq_length == attn_mask.shape[-1], f"{kv_seq_length}, {attn_mask.shape}" + output_mask = make_4d_float_attn_mask( + attn_mask=attn_mask, + query_seq_length=query_seq_length, + kv_seq_length=kv_seq_length, + dtype=dtype, + device=device, + batch_size=batch_size, + ) + if causal: + cache_position = ( + torch.arange(past_seen_tokens, past_seen_tokens + query_seq_length, device=device) + if cache_position is None + else cache_position + ) + output_mask = make_attn_mask_causal(output_mask, cache_position) + if ( + attn_implementation == SDPA_ATTN + and attn_mask is not None + and attn_mask.device.type == "cuda" + and not output_attentions + ): + output_mask = unmask_unattended(output_mask, mask_value=torch.finfo(dtype).min) + return output_mask + + +def expand_kv_heads(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + The equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). Convert hidden_states from + [batch, num_kv_heads, seqlen, head_dim] -> [batch, num_attention_heads, seqlen, head_dim] + """ + (batch, num_kv_heads, slen, head_dim) = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim) + + +class MultiheadAttention(torch.nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper + + Different implementation from torch.nn.MultiheadAttention to support: + - Easy switch between EAGER_ATTN, SDPA_ATTN and FLASH_ATTN + - Number of key-value heads different from query heads + - Key-value cache during forward, in the same way as transformers. 
Useful for generation or + cross-attention to projected keys and values + - Ability to apply positional encodings to key and value after input linear projection + - Different linear projection output size + + Adapted from transformers.models.llama.modeling_llama.LlamaAttention + """ + + def __init__( + self, + in_features: int, + num_heads: int, + head_dim: Optional[int] = None, + out_features: Optional[int] = None, + key_features: Optional[int] = None, + value_features: Optional[int] = None, + num_kv_heads: Optional[int] = None, + bias: bool = False, + dropout: float = 0.0, + cache_layer: Optional[int] = None, + query_position_embed: Optional[Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]] = None, + key_position_embed: Optional[Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]] = None, + ): + """ + Args: + in_features: Input dimension for query linear projection. + num_heads: Number of heads for query + head_dim: Head dimension. If None, defaults to `in_features // num_heads` + out_features: Output dimension for the output linear layer. If None, defaults to `in_features` + key_features: Input dimension for key linear projection. If None, defaults to `in_features` + value_features: Input dimension for value linear projection. If None, defaults to `in_features` + num_kv_heads: Number of heads for keys and values. If None, defaults to `num_heads` + cache_layer: Index of the layer in the cache. Needed only when `cache` is passed to + the `forward()` call, usually during generation or when the projected keys and values need + to be cached during training. Can be omitted when `cache_layer` is passed to `forward` + position_embed: Callable that takes as input linearly projected query and key and a tuple of + positional embeddings and returns query and key with positional embeddings applied. Note + these embeddings are applied after linear projection. If you want to apply embeddings before + the linear projection, do so before calling the forward method and use the default value + for `position_embed`, which is a simple pass-through. Note you can also pass torch.nn.Module + key_position_embed: Callable that takes as input linearly projected key and optional positional + index in the sequence and returns key with positional embeddings applied. + positional embeddings and returns query and key with positional embeddings applied. Note + these embeddings are applied after linear projection. If you want to apply embeddings before + the linear projection, do so before calling the forward method and use the default value + for `position_embed`, which is a simple pass-through. 
Note you can also pass torch.nn.Module + """ + super().__init__() + self.in_features = in_features + self.key_features = key_features or in_features + self.value_features = value_features or in_features + self.bias = bias + self.out_features = out_features or in_features + self.num_heads = num_heads + self.head_dim = head_dim or in_features // num_heads + self.num_kv_heads = num_kv_heads or num_heads + self.dropout = dropout + self.query_position_embed = query_position_embed + self.key_position_embed = key_position_embed + self.cache_layer = cache_layer + self.q_proj = torch.nn.Linear(self.in_features, self.num_heads * self.head_dim, bias=self.bias) + self.k_proj = torch.nn.Linear(self.key_features, self.num_kv_heads * self.head_dim, bias=self.bias) + self.v_proj = torch.nn.Linear(self.value_features, self.num_kv_heads * self.head_dim, bias=self.bias) + self.o_proj = torch.nn.Linear(self.num_heads * self.head_dim, self.out_features, bias=self.bias) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + query_position_indices: Optional[torch.Tensor] = None, + key_position_indices: Optional[torch.Tensor] = None, + cache: Optional[transformers.Cache] = None, + cache_layer: Optional[int] = None, + output_attentions: bool = False, + cache_kwargs: Dict[str, Any] = {}, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Args: + query: Query embedding of shape [B, L, in_features] + key: Key embedding of shape [B, S, key_features] + value: Value embedding of shape [B, S, value_features] + attn_mask: dtype torch.bool or same dtype as query/key/value and shape one of: + - [B, S] + - [B | 1, 1 | num_heads, L, S] + If bool, False values indicate masked positions (opposite of torch.nn.MultiheadAttention) + If float, must contain only 0.0 and torch.finfo(dtype).min + If attn_mask is None, full-bidirectional attention or causal attention is used depdening + on the value of `is_causal`. + is_causal: If True, apply additional causal masking to `attn_mask` + query_position_indices: torch.Tensor of shape [1 | B, L] containing the indices of the `query` + tokens within the entire sequence. Passed through to query_position_embed. If None and `cache` + is not None, indices are autogenerated [0, 1, ..., L] and offset by `cache_size` + key_position_indices: Same as `query_position_indices`, but applied to key + cache: transformers.Cache containing cached key-value pairs. The linearly projected + `key` and `value` passed to this function get added to the cache and concatenated after the + key-value pairs in the cache and then attention is computed on the concatenated sequence. + This is most commonly used at inference when generating auto-regressively or when one needs + to cross attend to the keys and values outside this module forward pass. + cache_layer: Index of the layer in the cache. Needed only when `cache` is passed to + the `forward()` call, usually during generation or when the projected keys and values need + to be cached during training. Can be omitted when `cache_layer` was passed to `__init__` + output_attentions: If True, output also the attention weights. Otherwise output None. + Note that only the eager implementation of MultiheadAttention supports this. 
+ cache_kwargs: kwargs directly passed to `cache.update()` + Returns: + Tuple with entries: + - Attention block output: torch.Tensor of shape [B, L, out_features] + - Optional attention weights if `output_attentions=True`, shape [B, num_heads, L, S] + """ + batch_size = query.shape[0] + query_states = self.q_proj(query) + key_states = self.k_proj(key) + value_states = self.v_proj(value) + query_states = query_states.view( + batch_size, query_states.shape[1], self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + batch_size, key_states.shape[1], self.num_kv_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + batch_size, value_states.shape[1], self.num_kv_heads, self.head_dim + ).transpose(1, 2) + (query_states, key_states) = self._maybe_apply_positional_embeddings( + query_states=query_states, + key_states=key_states, + query_position_indices=query_position_indices, + key_position_indices=key_position_indices, + cache=cache, + ) + (key_states, value_states) = self._maybe_update_cache( + key_states, + value_states, + cache_layer=cache_layer, + cache=cache, + cache_kwargs=cache_kwargs, + ) + key_states = expand_kv_heads(key_states, self.num_heads // self.num_kv_heads) + value_states = expand_kv_heads(value_states, self.num_heads // self.num_kv_heads) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + attn_mask = update_attn_mask( + attn_mask, + attn_implementation=EAGER_ATTN, + query_seq_length=query_states.shape[2], + kv_seq_length=value_states.shape[2], + cache_position=query_position_indices, + cache=cache, + batch_size=batch_size, + causal=is_causal, + dtype=query_states.dtype, + device=query_states.device, + output_attentions=output_attentions, + ) + if attn_mask is not None: + attn_mask = attn_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + attn_mask + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query_states.dtype + ) + attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + shape = (batch_size, self.num_heads, query.shape[1], self.head_dim) + assert attn_output.shape == shape, f"{attn_output.shape} != {shape}" + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim) + attn_output = self.o_proj(attn_output) + if not output_attentions: + attn_weights = None + return attn_output, attn_weights + + def _maybe_apply_positional_embeddings( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + query_position_indices: Optional[torch.Tensor], + key_position_indices: Optional[torch.Tensor], + cache: Optional[transformers.Cache], + ) -> Tuple[torch.Tensor, torch.Tensor]: + device = query_states.device + if self.query_position_embed is not None: + if query_position_indices is None and cache is not None: + query_position_indices = ( + torch.arange(query_states.shape[-2], dtype=torch.int64, device=device).view(1, -1) + + cache.get_seq_length() + ) + query_states = self.query_position_embed(query_states, position_indices=query_position_indices) + if self.key_position_embed is not None: + if key_position_indices is None and cache is not None: + key_position_indices = ( + torch.arange(key_states.shape[-2], dtype=torch.int64, device=device).view(1, -1) + + cache.get_seq_length() + ) + key_states = self.key_position_embed(key_states, 
position_indices=key_position_indices) + return query_states, key_states + + def _maybe_update_cache( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_layer: Optional[int], + cache: Optional[transformers.Cache], + cache_kwargs: Dict[str, Any], + ) -> Tuple[torch.Tensor, torch.Tensor]: + if cache is not None: + if cache_layer is None and self.cache_layer is None: + raise RuntimeError("When cache != None, cache_layer has to be set") + cache_layer = cache_layer if cache_layer is not None else self.cache_layer + (key_states, value_states) = cache.update(key_states, value_states, cache_layer, cache_kwargs) + return key_states, value_states + + +class MultiheadFlashAttention2(MultiheadAttention): + """ + MultiheadAttention implemented using flash attention module. Inherits `MultiheadAttention` as the weights + of the module stay untouched. The only change is on the forward pass where we call flash attention. + """ + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + query_position_indices: Optional[torch.Tensor] = None, + key_position_indices: Optional[torch.Tensor] = None, + cache: Optional[transformers.Cache] = None, + cache_layer: Optional[int] = None, + output_attentions: bool = False, + cache_kwargs: Dict[str, Any] = {}, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Args: + query: Query embedding of shape [B, L, in_features] + key: Key embedding of shape [B, S, key_features] + value: Value embedding of shape [B, S, value_features] + attn_mask: dtype torch.bool and shape [B, S]. + If bool, False values indicate masked positions (opposite of torch.nn.MultiheadAttention) + If attn_mask is None, full-bidirectional attention or causal attention is used depdening + on the value of `is_causal`. + NOTE: Doesn't support 4D attn_mask, unlike MultiheadAttention + is_causal: If True, apply additional causal masking to `attn_mask` + query_position_indices: torch.Tensor of shape [1 | B, L] containing the indices of the `query` + tokens within the entire sequence. Passed through to query_position_embed. If None and `cache` + is not None, indices are autogenerated [0, 1, ..., L] and offset by `cache_size` + key_position_indices: Same as `query_position_indices`, but applied to key + cache: transformers.Cache containing cached key-value pairs. The linearly projected + `key` and `value` passed to this function get added to the cache and concatenated after the + key-value pairs in the cache and then attention is computed on the concatenated sequence. + This is most commonly used at inference when generating auto-regressively or when one needs + to cross attend to the keys and values outside this module forward pass. + cache_layer: Index of the layer in the cache. Needed only when `cache` is passed to + the `forward()` call, usually during generation or when the projected keys and values need + to be cached during training. Can be omitted when `cache_layer` was passed to `__init__` + output_attentions: If True, output also the attention weights. Otherwise output None. + Note that only the eager implementation of MultiheadAttention supports this. 
+ cache_kwargs: kwargs directly passed to `cache.update()` + Returns: + Tuple with entries: + - Attention block output: torch.Tensor of shape [B, L, out_features] + - Optional attention weights if `output_attentions=True`, shape [B, num_heads, L, S] + """ + if isinstance(cache, transformers.StaticCache): + raise ValueError( + "transformers.StaticCache not compatible with flash attention. Use `sdpa` instead (for now)." + ) + assert output_attentions is False, f"{self.__class__} doesn't support output_attentions=True" + batch_size = query.shape[0] + query_states = self.q_proj(query) + key_states = self.k_proj(key) + value_states = self.v_proj(value) + query_states = query_states.view( + batch_size, query_states.shape[1], self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + batch_size, key_states.shape[1], self.num_kv_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + batch_size, value_states.shape[1], self.num_kv_heads, self.head_dim + ).transpose(1, 2) + (query_states, key_states) = self._maybe_apply_positional_embeddings( + query_states=query_states, + key_states=key_states, + query_position_indices=query_position_indices, + key_position_indices=key_position_indices, + cache=cache, + ) + (key_states, value_states) = self._maybe_update_cache( + key_states, + value_states, + cache_layer=cache_layer, + cache=cache, + cache_kwargs=cache_kwargs, + ) + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + attn_mask = update_attn_mask( + attn_mask, + attn_implementation=FLASH_ATTN, + query_seq_length=query_states.shape[2], + kv_seq_length=value_states.shape[2], + cache_position=query_position_indices, + cache=cache, + batch_size=batch_size, + causal=is_causal, + dtype=query_states.dtype, + device=query_states.device, + output_attentions=output_attentions, + ) + raise NotImplementedError("Correctness not yet confirmed") + attn_output = transformers.modeling_flash_attention_utils._flash_attention_forward( + query_states=query_states, + key_states=key_states, + value_states=value_states, + attention_mask=attn_mask, + query_length=query.shape[1], + position_ids=None, + dropout=self.dropout if self.training else 0.0, + sliding_window=None, + use_top_left_mask=False, + is_causal=is_causal, + deterministic=True, + ) + size = (batch_size, self.num_heads, query.shape[1], self.head_dim) + if attn_output.size() != size: + raise ValueError(f"`attn_output` should be of size {size}, but is {attn_output.size()}") + shape = (batch_size, self.num_heads, query.shape[1], self.head_dim) + assert attn_output.shape == shape, f"{attn_output.shape} != {shape}" + attn_output = attn_output.reshape(batch_size, query.shape[1], -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, None + + +class MultiheadSdpaAttention(MultiheadAttention): + """ + MultiheadAttention SDPA attention. Inherits `MultiheadAttention` as the weights of the module stay untouched. 
+ The only change is on the forward pass where we call `torch.nn.functional.scaled_dot_product_attention` + """ + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + query_position_indices: Optional[torch.Tensor] = None, + key_position_indices: Optional[torch.Tensor] = None, + cache: Optional[transformers.Cache] = None, + cache_layer: Optional[int] = None, + output_attentions: bool = False, + cache_kwargs: Dict[str, Any] = {}, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Args: + query: Query embedding of shape [B, L, in_features] + key: Key embedding of shape [B, S, key_features] + value: Value embedding of shape [B, S, value_features] + attn_mask: dtype torch.bool or same dtype as query/key/value and shape one of: + - [B, S] + - [B | 1, 1 | num_heads, L, S] + If bool, False values indicate masked positions (opposite of torch.nn.MultiheadAttention) + If float, must contain only 0.0 and torch.finfo(dtype).min + If attn_mask is None, full-bidirectional attention or causal attention is used depdening + on the value of `is_causal`. + is_causal: If True, apply additional causal masking to `attn_mask` + query_position_indices: torch.Tensor of shape [1 | B, L] containing the indices of the `query` + tokens within the entire sequence. Passed through to query_position_embed. If None and `cache` + is not None, indices are autogenerated [0, 1, ..., L] and offset by `cache_size` + key_position_indices: Same as `query_position_indices`, but applied to key + cache: transformers.Cache containing cached key-value pairs. The linearly projected + `key` and `value` passed to this function get added to the cache and concatenated after the + key-value pairs in the cache and then attention is computed on the concatenated sequence. + This is most commonly used at inference when generating auto-regressively or when one needs + to cross attend to the keys and values outside this module forward pass. + cache_layer: Index of the layer in the cache. Needed only when `cache` is passed to + the `forward()` call, usually during generation or when the projected keys and values need + to be cached during training. Can be omitted when `cache_layer` was passed to `__init__` + output_attentions: If True, output also the attention weights. Otherwise output None. + Note that only the eager implementation of MultiheadAttention supports this. 
+ cache_kwargs: kwargs directly passed to `cache.update()` + Returns: + Tuple with entries: + - Attention block output: torch.Tensor of shape [B, L, out_features] + - Optional attention weights if `output_attentions=True`, shape [B, num_heads, L, S] + """ + assert output_attentions is False, f"{self.__class__} doesn't support output_attentions=True" + batch_size = query.shape[0] + query_states = self.q_proj(query) + key_states = self.k_proj(key) + value_states = self.v_proj(value) + query_states = query_states.view( + batch_size, query_states.shape[1], self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + batch_size, key_states.shape[1], self.num_kv_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + batch_size, value_states.shape[1], self.num_kv_heads, self.head_dim + ).transpose(1, 2) + (query_states, key_states) = self._maybe_apply_positional_embeddings( + query_states=query_states, + key_states=key_states, + query_position_indices=query_position_indices, + key_position_indices=key_position_indices, + cache=cache, + ) + (key_states, value_states) = self._maybe_update_cache( + key_states, + value_states, + cache_layer=cache_layer, + cache=cache, + cache_kwargs=cache_kwargs, + ) + key_states = expand_kv_heads(key_states, self.num_heads // self.num_kv_heads) + value_states = expand_kv_heads(value_states, self.num_heads // self.num_kv_heads) + attn_mask = update_attn_mask( + attn_mask, + attn_implementation=SDPA_ATTN, + query_seq_length=query_states.shape[2], + kv_seq_length=value_states.shape[2], + cache_position=query_position_indices, + cache=cache, + batch_size=batch_size, + causal=is_causal, + dtype=query_states.dtype, + device=query_states.device, + output_attentions=output_attentions, + ) + if attn_mask is not None: + attn_mask = attn_mask[:, :, :, : key_states.shape[-2]] + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attn_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, + ) + shape = (batch_size, self.num_heads, query.shape[1], self.head_dim) + assert attn_output.shape == shape, f"{attn_output.shape} != {shape}" + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, query.shape[1], self.num_heads * self.head_dim) + attn_output = self.o_proj(attn_output) + return attn_output, None + + +ATTN_TYPES = { + EAGER_ATTN: MultiheadAttention, + SDPA_ATTN: MultiheadSdpaAttention, + FLASH_ATTN: MultiheadFlashAttention2, +} + + +def make_activation(activation: str | Type[torch.nn.Module], **kwargs) -> torch.nn.Module: + if isinstance(activation, str): + TorchActivation: Type[torch.nn.Module] = getattr(torch.nn, activation) + else: + TorchActivation: Type[torch.nn.Module] = activation + assert issubclass(TorchActivation, torch.nn.Module), TorchActivation + return TorchActivation(**kwargs) + + +class PiZeroMLP(torch.nn.Module): + def __init__(self, feature_size: int, hidden_size: int, activation: str): + super().__init__() + self.gate_proj = torch.nn.Linear(feature_size, hidden_size, bias=False) + self.up_proj = torch.nn.Linear(feature_size, hidden_size, bias=False) + self.down_proj = torch.nn.Linear(hidden_size, feature_size, bias=False) + self.activation = make_activation(activation, approximate="tanh") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x)) + + +class 
PiZeroFlowMatchingDecoderBlock(ConfigurableModule): + def __init__(self, config: PiZeroFlowMatchingDecoderBlockConfig, **attn_kwargs): + super().__init__(config) + self.norm_in = GemmaRMSNorm(self.config.feature_size, eps=1e-06) + self.self_attn = ATTN_TYPES[self.config.attn_implementation]( + in_features=self.config.feature_size, + num_heads=self.config.num_heads, + head_dim=self.config.head_dim, + num_kv_heads=self.config.num_kv_heads, + **attn_kwargs, + ) + self.mlp = PiZeroMLP( + feature_size=self.config.feature_size, + hidden_size=self.config.hidden_size, + activation=self.config.activation, + ) + self.norm_out = GemmaRMSNorm(self.config.feature_size, eps=1e-06) + + def forward( + self, + query: torch.Tensor, + attn_mask: torch.Tensor, + cache: transformers.Cache, + attn_kwargs: Dict[str, Any], + ) -> torch.Tensor: + """ + Args: + query: torch.Tensor of shape [B, L, token_size]. The query seqence in the order: + [noised query tokens, condition token, robot state tokens] + timestep: torch.Tensor of shape [B, 1, token_size]. Timestep token + attn_mask: torch.Tensor of shape [B, 1, L, L+S] and dtype torch.bool, where S is the VLM + sequence length + cache: Cache that contains only the VLM tokens during training and VLM + past query tokens + during generation + num_noised_tokens: Number of noised tokens in `query` + num_condition_tokens: Number of condition tokens in `query` + Returns: + torch.Tensor of same shape as query [B, L, token_size] + """ + residual = x = query + x = self.norm_in(x) + (x, _) = self.self_attn( + query=x, + key=x, + value=x, + attn_mask=attn_mask, + is_causal=False, + cache=cache, + **attn_kwargs, + ) + x = residual + x + residual = x + x = self.norm_out(x) + x = self.mlp(x) + x = residual + x + return x + + +class PiZeroFlowMatchingDecoder(ConfigurableModule): + """PiZero Flow Matching control decoder""" + + def __init__(self, config: PiZeroFlowMatchingDecoderConfig): + super().__init__(config) + query_position_embed = RotaryPositionalEncoding(config=self.config.block_config.position_embed_config) + key_position_embed = RotaryPositionalEncoding(config=self.config.block_config.position_embed_config) + self.blocks = torch.nn.ModuleList( + [ + PiZeroFlowMatchingDecoderBlock( + self.config.block_config, + query_position_embed=query_position_embed, + key_position_embed=key_position_embed, + cache_layer=i, + ) + for i in range(self.config.num_blocks) + ] + ) + self.norm = GemmaRMSNorm(self.config.block_config.feature_size, eps=1e-06) + + def forward( + self, + control_tokens: torch.Tensor, + robot_state_tokens: torch.Tensor, + llm_kv_tokens: List[Tuple[torch.Tensor, torch.Tensor]], + attn_mask: Optional[torch.Tensor], + cache: Optional[transformers.Cache] = None, + ) -> torch.Tensor: + """ + Args: + control_tokens: torch.Tensor of shape [B, N, token_size], contains sequence of controls + robot_state_tokens: torch.Tensor of shape [B, num_state_tokens, token_size] + llm_kv_tokens: List of linearly projected key-value pairs from LLM, right before attention + operation. 
Each tensor is of the shape [B, num_kv_heads, S, head_dim] + attn_mask: One of + - shape [B, S], dtype torch.bool -> padding attention mask for LLM tokens + - shape [B, 1, L, S], dtype torch.bool -> full attention mask for LLM tokens + Returns: + torch.Tensor, shape [B, N, token_size] + """ + assert ( + len(llm_kv_tokens) == self.config.num_blocks + ), f"{len(llm_kv_tokens)} != {self.config.num_blocks}" + is_step_zero = cache.get_seq_length() == 0 if cache is not None else True + vlm_seq_len = attn_mask.shape[-1] + device = attn_mask.device + if cache is None: + cache = transformers.DynamicCache() + if is_step_zero: + position_indices = torch.arange(vlm_seq_len, dtype=torch.int64, device=device) + for block_index, kv_tokens in enumerate(llm_kv_tokens): + (key_states, value_states) = kv_tokens + cache.update( + key_states, + value_states, + block_index, + cache_kwargs={"cache_position": position_indices}, + ) + num_control_tokens = control_tokens.shape[1] + num_robot_state_tokens = robot_state_tokens.shape[1] + attn_mask = self._build_attn_mask( + num_control_tokens=num_control_tokens, + num_robot_state_tokens=num_robot_state_tokens, + attn_mask=attn_mask, + ) + if is_step_zero: + tokens = torch.cat([robot_state_tokens, control_tokens], axis=1) + query_position_indices = key_position_indices = vlm_seq_len + torch.arange( + tokens.shape[1], dtype=torch.int64, device=device + ).view(1, -1) + else: + tokens = control_tokens + attn_mask = attn_mask[:, :, -control_tokens.shape[1] :] + query_position_indices = key_position_indices = ( + vlm_seq_len + + num_robot_state_tokens + + torch.arange(tokens.shape[1], dtype=torch.int64, device=device).view(1, -1) + ) + for block in self.blocks: + tokens = block( + query=tokens, + attn_mask=attn_mask, + cache=cache, + attn_kwargs={ + "query_position_indices": query_position_indices, + "key_position_indices": key_position_indices, + "cache_kwargs": {"cache_position": key_position_indices.view(-1)}, + }, + ) + if is_step_zero: + (_, control_tokens) = torch.split(tokens, [num_robot_state_tokens, num_control_tokens], dim=1) + else: + control_tokens = tokens + control_tokens = self.norm(control_tokens) + return control_tokens + + @torch.no_grad() + def _build_attn_mask( + self, + num_control_tokens: int, + num_robot_state_tokens: int, + attn_mask: torch.Tensor, + ) -> torch.Tensor: + """ + Expand `attn_mask` (which is effectively a padding mask) to 4D such that: + - robot state tokens and control tokens can't attend to padding tokens + - robot state tokens can't attend to control tokens + Note: We can't keep the mask in 2D as it doesn't allow masking of padding tokens from the + VLM sequence. 
Furthermore, in a 2D mask you can't disable attention from robot state tokens + to control tokens + """ + assert attn_mask.dtype == torch.bool, attn_mask.dtype + assert attn_mask.ndim in [2, 4], attn_mask.shape + device = attn_mask.device + batch_size = attn_mask.shape[0] + query_seq_len = num_robot_state_tokens + num_control_tokens + vlm_seq_len = attn_mask.shape[-1] + kv_seq_len = query_seq_len + vlm_seq_len + cross_attn_mask = torch.ones( + [batch_size, 1, query_seq_len, kv_seq_len], dtype=torch.bool, device=device + ) + if attn_mask.ndim == 2: + attn_mask = attn_mask.view(batch_size, 1, 1, vlm_seq_len) + else: + attn_mask = torch.any(attn_mask, dim=-2, keepdims=True) + cross_attn_mask[..., :vlm_seq_len] = attn_mask + robot_state_query_indices = torch.arange( + num_robot_state_tokens, dtype=torch.int64, device=device + ).view(-1, 1) + control_key_indices = ( + torch.arange(num_control_tokens, dtype=torch.int64, device=device).view(-1, 1) + + vlm_seq_len + + num_robot_state_tokens + ) + cross_attn_mask[:, :, robot_state_query_indices, control_key_indices] = 0 + return cross_attn_mask + + @property + def fsdp_wrap_modules(self) -> Set[torch.nn.Module]: + return {module for module in self.modules() if isinstance(module, type(self.blocks[0]))} | {self.norm} + + +def integrate_unitquat( + qt: torch.Tensor, + dq_dt: torch.Tensor, + dt: float | torch.Tensor, + body_frame: bool = True, + half_cover: bool = True, +) -> torch.Tensor: + """ + Integrate a unit quaternion `qt` by the derivative `dq_dt` over the time interval `dt`. + Args: + qt: Unit quaternion, shape [..., 4] + dq_dt: Derivative of the unit quaternion, shape [..., 4] + dt: Time interval to integrate over, scalar or a tensor of shape () or [..., 1] + half_cover: If True, the result is guaranteed to lie in the half space + body_frame: If True, the integration is done in the body frame (post-multiply), + otherwise in the inertial frame (pre-multiply). 
+ Returns: + Integrated unit quaternion, shape [..., 4] + """ + assert qt.shape == dq_dt.shape, f"{qt.shape} != {dq_dt.shape}" + assert is_quaternion(qt), f"{qt.shape} not a quaternion" + if isinstance(dt, torch.Tensor): + assert dt.ndim in (0, qt.ndim), f"dt.ndim = {dt.ndim} | {qt.ndim}" + if body_frame: + omega_q = 2.0 * roma.quat_product(roma.quat_conjugation(qt), dq_dt) + else: + omega_q = 2.0 * roma.quat_product(dq_dt, roma.quat_conjugation(qt)) + omega = omega_q[..., :-1] + dq = roma.rotvec_to_unitquat(omega * dt) + if body_frame: + qt = roma.quat_product(qt, dq) + else: + qt = roma.quat_product(dq, qt) + if half_cover: + qt = quaternion_half_cover(qt) + return qt + + +def rotmat_inverse(rotation: torch.Tensor) -> torch.Tensor: + assert is_rotmat(rotation), f"Expected a rotation matrix, but got shape {rotation.shape}" + rotmat = rotmat_as_3x3(rotation) + rotmat = rotmat.transpose(-1, -2) + if is_rotmat_9(rotation): + rotmat = rotmat_as_9(rotmat) + return rotmat + + +def skew_symmetric_to_rotvec(skew_symmetric: torch.Tensor) -> torch.Tensor: + """ + Convert a skew-symmetric matrix to a rotation vector in a differentiable way + [ + [ 0, -z, y], + [ z, 0, -x], + [-y, x, 0], + ] + Args: + skew_symmetric: Skew-symmetric matrix of shape [..., 3, 3] + Returns: + torch.Tensor of shape [..., 3] + """ + assert is_rotmat(skew_symmetric), skew_symmetric.shape + rotvec = torch.stack( + ( + skew_symmetric[..., 2, 1] - skew_symmetric[..., 1, 2], + skew_symmetric[..., 0, 2] - skew_symmetric[..., 2, 0], + skew_symmetric[..., 1, 0] - skew_symmetric[..., 0, 1], + ), + dim=-1, + ) + rotvec = rotvec / 2.0 + return rotvec + + +def integrate_rotmat( + rt: torch.Tensor, + dr_dt: torch.Tensor, + dt: float | torch.Tensor, + body_frame: bool = True, +) -> torch.Tensor: + """ + Integrate a rotation matrix `rt` by the derivative `dr_dt` over the time interval `dt`. + Args: + rt: Rotation matrix, shape [..., 3, 3] + dr_dt: Derivative of the rotation matrix, shape [..., 3, 3] + dt: Time interval to integrate over, scalar or a tensor of shape () or [..., 1] + body_frame: If True, the integration is done in the body frame (post-multiply), + otherwise in the inertial frame (pre-multiply). + Returns: + Integrated unit quaternion, shape [..., 4] + """ + assert rt.shape == dr_dt.shape, f"{rt.shape} != {dr_dt.shape}" + assert is_rotmat(rt), f"{rt.shape} not a rotation matrix" + is_3x3 = is_rotmat_3x3(rt) + if not is_3x3: + rt = rotmat_as_3x3(rt) + dr_dt = rotmat_as_3x3(dr_dt) + if isinstance(dt, torch.Tensor): + assert dt.ndim in ( + 0, + rt.ndim, + rt.ndim - 1, + ), f"dt.ndim = {dt.ndim} | {rt.ndim} | {rt.ndim - 1}" + if dt.ndim == rt.ndim: + assert dt.shape[-2:] == (1, 1), dt.shape + dt = dt.squeeze(-1) + if body_frame: + omega = skew_symmetric_to_rotvec(rotmat_inverse(rt) @ dr_dt) + else: + omega = skew_symmetric_to_rotvec(dr_dt @ rotmat_inverse(rt)) + dr = roma.rotvec_to_rotmat(omega * dt) + if body_frame: + rt = rt @ dr + else: + rt = dr @ rt + if not is_3x3: + rt = rotmat_as_9(rt) + return rt + + +def integrate_rotation( + rt: torch.Tensor, + dr_dt: torch.Tensor, + dt: float | torch.Tensor, + body_frame: bool = True, + half_cover: bool = True, +) -> torch.Tensor: + """ + Integrate the rotation `rt` by the derivative `dr_dt` over the time interval `dt` on the SO(3) manifold. 
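+
+    Dispatches on the rotation format of `rt`: quaternions ([..., 4]) are integrated with
+    `integrate_unitquat`, rotation matrices ([..., 3, 3] or [..., 9]) with `integrate_rotmat`;
+    `half_cover` is only used in the quaternion case. Arguments and the return value follow
+    those two functions.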
+ """ + if is_quaternion(rt): + return integrate_unitquat(rt, dr_dt, dt, body_frame=body_frame, half_cover=half_cover) + if is_rotmat(rt): + return integrate_rotmat(rt, dr_dt, dt, body_frame=body_frame) + raise NotImplementedError(f"integrate_rotation not yet implemented for format {rt.shape}") + + +class PiZeroFlowMatchingModule(ConfigurableModule): + def __init__(self, config: PiZeroFlowMatchingModuleConfig, control_tokenizer: EmptyTokenizer): + super().__init__(config) + del control_tokenizer + self.noised_control_proj = NoisedControlProjector(self.config.noised_control_proj_config) + self.robot_state_proj = RobotStateProjector(self.config.robot_state_proj_config) + self.control_decoder = PiZeroFlowMatchingDecoder(config=self.config.control_decoder_config) + self.output_proj = make_mlp( + [self.config.token_size, 3 + self.config.rotation_components + 1], + activation=torch.nn.GELU, + activate_final=False, + ) + + def forward( + self, + vlm_input: RoboticsFlowInput, + vlm_output: VLMOutput, + cache: Optional[transformers.Cache] = None, + ) -> RoboticsOutput: + robot_state_tokens = self.robot_state_proj(vlm_input) + noised_tokens = self.noised_control_proj(vlm_input.flow_input) + output_tokens = self.control_decoder( + control_tokens=noised_tokens, + robot_state_tokens=robot_state_tokens, + llm_kv_tokens=vlm_output.llm_output.past_key_values, + attn_mask=vlm_input.attn_mask, + cache=cache, + ) + contols = self.output_proj(output_tokens) + (translation, rotation, gripper) = torch.split( + contols, [3, self.config.rotation_components, 1], dim=-1 + ) + return RoboticsOutput.make_empty().replace( + translation=translation, rotation=rotation, gripper=gripper + ) + + @torch.inference_mode() + def generate( + self, + vlm_input: RoboticsFlowInput, + vlm_output: VLMOutput, + processor: PiZeroFlowMatchingProcessor, + use_cache: bool = True, + **kwargs, + ) -> RoboticsOutput: + del kwargs + (batch_size, vlm_seq_len) = vlm_input.input_ids.shape[:2] + device = vlm_input.input_ids.device + if use_cache: + max_cache_len = ( + vlm_seq_len + + processor.config.control_io_config.future_controls_sequence_length + + processor.config.control_io_config.past_scalars_sequence_length + ) + cache = transformers.StaticCache( + config=transformers.PretrainedConfig( + head_dim=self.config.control_decoder_config.block_config.head_dim, + num_key_value_heads=self.config.control_decoder_config.block_config.num_kv_heads, + num_hidden_layers=self.config.control_decoder_config.num_blocks, + ), + max_batch_size=batch_size, + max_cache_len=max_cache_len, + device=device, + ) + else: + cache = None + flow_input: FlowInput = processor.sample_t0_input(batch_size=batch_size, device=device) + step_size = 1 / processor.config.num_inference_steps + translation = flow_input.translation_t0 + rotation = flow_input.rotation_t0 + gripper = flow_input.gripper_t0 + vlm_input = vlm_input.replace( + **{ + "flow_input.timestep": flow_input.timestep, + "flow_input.translation_t": translation, + "flow_input.rotation_t": rotation, + "flow_input.gripper_t": gripper, + } + ) + for _ in range(processor.config.num_inference_steps): + model_output: RoboticsOutput = self(vlm_input, vlm_output, cache) + translation = translation + step_size * model_output.translation + rotation = integrate_rotation(rt=rotation, dr_dt=model_output.rotation, dt=step_size) + gripper = gripper + step_size * model_output.gripper + timestep = vlm_input.flow_input.timestep + step_size + if processor.config.rotation_format == RotationFormat.QUATERNION: + rotation = 
quaternion_half_cover(rotation) + vlm_input = vlm_input.replace( + **{ + "flow_input.timestep": timestep, + "flow_input.translation_t": translation, + "flow_input.rotation_t": rotation, + "flow_input.gripper_t": gripper, + } + ) + output = RoboticsOutput.make_empty().replace( + translation=translation, rotation=rotation, gripper=gripper + ) + return output + + @property + def fsdp_wrap_modules(self) -> Set[torch.nn.Module]: + return self.control_decoder.fsdp_wrap_modules | { + self, + self.robot_state_proj, + self.noised_control_proj, + self.output_proj, + } + + +CANONICAL_TO_BRIDGE_ROTATION = np.array( + [ + [1, 0, 0], + [0, np.cos(np.pi), -np.sin(np.pi)], + [0, np.sin(np.pi), np.cos(np.pi)], + ], + dtype=np.float32, +) + + +class SPEAR1(ConfigurableModule, transformers.PreTrainedModel): + config_class: transformers.PretrainedConfig = SPEAR1Config + + def __init__(self, config: SPEAR1Config): + super().__init__(config) + self.vlm = PaliGemmaVLM(config=self.config.vlm_config) + self.processor = PiZeroFlowMatchingProcessor( + config=self.config.processor_config, vlm_processor=self.vlm.processor + ) + self.control_module = PiZeroFlowMatchingModule( + config=self.config.control_module_config, + control_tokenizer=self.processor.control_tokenizer, + ) + self.generation_config = transformers.GenerationConfig() + + def forward( + self, + inputs: RoboticsInput, + use_cache: Optional[bool] = True, + output_hidden_states: Optional[bool] = None, + ) -> RoboticsOutput: + del output_hidden_states + vlm_output = self.vlm(inputs=inputs, use_cache=use_cache, output_hidden_states=True) + control_output = self.control_module(vlm_input=inputs, vlm_output=vlm_output) + output = control_output.replace(llm_output=vlm_output.llm_output) + return output + + @torch.inference_mode() + def generate( + self, + inputs: RoboticsInput, + use_cache: Optional[bool] = True, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> RoboticsOutput: + del output_hidden_states + vlm_output = self.vlm( + inputs=inputs, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=True, + ) + control_output = self.control_module.generate( + vlm_input=inputs, vlm_output=vlm_output, processor=self.processor + ) + output = control_output.replace(llm_output=vlm_output.llm_output) + return output + + def predict_action(self, inputs: Dict) -> Dict[str, np.ndarray]: + images = inputs["images"] + ee_translation = inputs["ee_translation"] + ee_rotation = inputs["ee_rotation"] + gripper = inputs["gripper"] + + num_resize_args = len(inspect.signature(self.processor.resize_image).parameters) + # Resize images using the processor's resize_image method + for camera_name, camera_image in images.items(): + # Handle the different signatures resize_image - old one used to take only the image, + # new one also takes the camera name + if num_resize_args == 1: + images[camera_name] = self.processor.resize_image(camera_image) + elif num_resize_args == 2: + images[camera_name] = self.processor.resize_image(camera_name, camera_image) + else: + raise ValueError(f"Unexpected number of arguments for resize_image: {num_resize_args}") + # add batch dimension and wrap into list to match processor expected format + images[camera_name] = [images[camera_name]] + + # add batch dimensions to state obs + ee_translation = np.array(ee_translation, dtype=np.float32).reshape(1, 3) + ee_rotation = np.array(ee_rotation, dtype=np.float32).reshape(1, 3, 3) @ CANONICAL_TO_BRIDGE_ROTATION + gripper = 
np.array(gripper, dtype=np.float32).reshape(1, 1) + joints = np.zeros((1, 7), dtype=np.float32) + + dataset_name = np.array([inputs["dataset_name"]]) + + chat = [f"{inputs['language_instruction']}", ""] + + model_input = self.processor.create_input( + images=images, + chat=chat, + ee_pose_translation=ee_translation, + ee_pose_rotation=ee_rotation, + gripper=gripper, + dataset_name=dataset_name, + joints=joints, + inference_mode=True, + ) + + model_input = model_input.apply( + lambda x: x.unsqueeze(0).to("cuda") if isinstance(x, torch.Tensor) else x + ) + + with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16): + model_output = self.generate(model_input) + + control_plan = self.processor.policy_control_plan_from_model_output( + model_output=model_output, + dataset_name=dataset_name, + valid_mask=torch.ones( + model_output.gripper.shape[:2], dtype=torch.bool, device=model_output.gripper.device + ), + ) + translation_m = control_plan.translation_m.to(dtype=torch.float32, device='cpu') + rotation = control_plan.rotmat.to(dtype=torch.float32, device='cpu') + gripper_prob = control_plan.gripper_prob.to(dtype=torch.float32, device='cpu') + + # Convert controls back to robot base frame + if self.processor.config.eef_control_frame: + # Get the robot base rotation matrix R_BE - the same as the robot EEF pose. + # R_BE - converts from end-effector frame E to robot base frame B + robot_base_rotmat = rotmat_as_3x3(model_input.ee_pose_rotation[:, -1:, ...]).cpu() # [B, 1, 3, 3] + translation_m = torch.matmul( # [B, num_future_control_steps, 3] + robot_base_rotmat, translation_m.unsqueeze(-1) + ).squeeze(-1) + rotation = rotmat_as_3x3( # [B, num_future_control_steps, 3, 3] + torch.matmul(robot_base_rotmat, rotmat_as_3x3(rotation)) + ) + + translation = translation_m # [B, num_future_control_steps, 3] + rotation = rotmat_as_3x3(rotation) # [B, num_future_control_steps, 3, 3] + gripper = gripper_prob # [B, num_future_control_steps, 1] + + translation = translation.squeeze(0).numpy() + rotation = rotation.squeeze(0).numpy() + gripper = gripper.squeeze(0).numpy() + + rotation = CANONICAL_TO_BRIDGE_ROTATION @ rotation @ CANONICAL_TO_BRIDGE_ROTATION.T + + return { + "translation": translation, + "rotation": rotation, + "gripper": gripper, + } + + @property + def fsdp_wrap_modules(self) -> Set[torch.nn.Module]: + return ( + {self.vlm, self.control_module} + | self.vlm.fsdp_wrap_modules + | self.control_module.fsdp_wrap_modules + )
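+
+
+# ---------------------------------------------------------------------------
+# Usage sketches (added for illustration; not part of the SPEAR-1 training or
+# inference pipeline). They only exercise utilities defined in this file plus
+# placeholder tensors; the checkpoint id, camera name and image size in the
+# commented `predict_action` call are hypothetical.
+# ---------------------------------------------------------------------------
+if __name__ == "__main__":
+    torch.manual_seed(0)
+
+    # 1) The [B, S] boolean padding-mask convention documented on
+    #    `MultiheadAttention.forward`: False marks masked (padded) key positions.
+    mha = MultiheadSdpaAttention(in_features=64, num_heads=4)
+    x = torch.randn(2, 5, 64)
+    pad_mask = torch.ones(2, 5, dtype=torch.bool)
+    pad_mask[1, -2:] = False  # the last two key tokens of sample 1 are padding
+    out, _ = mha(query=x, key=x, value=x, attn_mask=pad_mask, is_causal=False)
+    assert out.shape == (2, 5, 64)
+
+    # 2) One SO(3) Euler step with `integrate_rotation`, the same update used by
+    #    `PiZeroFlowMatchingModule.generate` to advance the predicted rotation.
+    def hat(v: torch.Tensor) -> torch.Tensor:
+        # Skew-symmetric matrix [v]x, same layout as in `skew_symmetric_to_rotvec`.
+        zero = torch.zeros_like(v[..., 0])
+        return torch.stack(
+            (
+                torch.stack((zero, -v[..., 2], v[..., 1]), dim=-1),
+                torch.stack((v[..., 2], zero, -v[..., 0]), dim=-1),
+                torch.stack((-v[..., 1], v[..., 0], zero), dim=-1),
+            ),
+            dim=-2,
+        )
+
+    rt = roma.random_rotmat(4)       # [4, 3, 3] current rotations
+    omega = 0.1 * torch.randn(4, 3)  # body-frame angular velocities
+    dr_dt = rt @ hat(omega)          # body-frame kinematics: dR/dt = R [omega]x
+    rt_next = integrate_rotation(rt, dr_dt, dt=0.1, body_frame=True)
+    # The integrated matrices stay on SO(3): R^T R == I up to numerical error.
+    identity = torch.eye(3).expand_as(rt_next)
+    assert torch.allclose(rt_next.transpose(-1, -2) @ rt_next, identity, atol=1e-5)
+
+    # 3) Hypothetical end-to-end `predict_action` call (commented out: it needs a
+    #    real checkpoint and a CUDA device; every literal below is a placeholder):
+    #
+    # model = SPEAR1.from_pretrained("<repo-or-path>").to("cuda").eval()
+    # controls = model.predict_action(
+    #     {
+    #         "images": {"primary": np.zeros((224, 224, 3), dtype=np.uint8)},
+    #         "ee_translation": np.zeros(3, dtype=np.float32),
+    #         "ee_rotation": np.eye(3, dtype=np.float32),
+    #         "gripper": np.zeros(1, dtype=np.float32),
+    #         "dataset_name": "<dataset-name>",
+    #         "language_instruction": "pick up the object",
+    #     }
+    # )
+    # # controls["translation"]: [T, 3], controls["rotation"]: [T, 3, 3],
+    # # controls["gripper"]: [T, 1] float numpy arrays (T = future control steps).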