# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Misc functions and modules for Cosmos-Embed1."""

import functools
from logging import getLogger
from typing import Callable, Optional, Protocol

import torch
import torch.distributed as dist
import torch.nn as nn

logger = getLogger(__file__)


def get_rank(group: Optional[dist.ProcessGroup] = None) -> int:
    """Get the rank (GPU device) of the worker.

    Returns:
        rank (int): The rank of the worker.
    """
    rank = 0
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank(group)
    return rank


def barrier() -> None:
    """Barrier for all GPUs."""
    if dist.is_available() and dist.is_initialized():
        dist.barrier()


def rank0_first(func: Callable) -> Callable:
    """Run the function on rank 0 first, then on other ranks."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):  # noqa: ANN202
        if get_rank() == 0:
            result = func(*args, **kwargs)
        barrier()
        if get_rank() != 0:
            result = func(*args, **kwargs)
        return result

    return wrapper


def add_docstring(docstring: str):
    """Decorator that attaches `docstring` to the decorated function."""

    def decorator(func):
        func.__doc__ = docstring
        return func

    return decorator


INIT_DOCSTRING = """
Constructor for encoding module.

Args:
    embed_dim: size of embedding vectors, e.g. x.shape[3].
    max_len: maximum length of temporal sequence, e.g. x.shape[1].
"""

FORWARD_DOCSTRING = """
Forward function.

Args:
    x (`torch.Tensor`): rank 4 tensor to add spatio-temporal encodings to.

Returns:
    `torch.Tensor` of rank 4.
"""


class EncodingProtocol(Protocol):
    """Interface that temporal encoding modules are expected to implement."""

    def __init__(self, embed_dim: int, max_len: int) -> None:
        pass

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass


def interpolate_temp_pos_embed(temp_embed: torch.Tensor, num_frames: int) -> torch.Tensor:
    """Linearly interpolates temporal encoding from `temp_embed.shape[0]` to `num_frames`."""
    temp_embed_resized = temp_embed.permute(1, 0).unsqueeze(0)
    temp_embed_resized = nn.functional.interpolate(
        temp_embed_resized,
        size=num_frames,
        mode="linear",
        align_corners=False,
    )
    return temp_embed_resized.squeeze(0).permute(1, 0)


class TemporalParameterEncoding(nn.Module, EncodingProtocol):
    """Learned per-frame temporal embeddings, broadcast-added to every query token."""

    @add_docstring(INIT_DOCSTRING)
    def __init__(self, embed_dim: int, max_len: int) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.max_len = max_len
        self.temp_embed = nn.Parameter(torch.zeros(self.max_len, self.embed_dim))
        nn.init.trunc_normal_(self.temp_embed, std=0.02)

    @add_docstring(FORWARD_DOCSTRING)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, t, _, _ = x.shape
        if t != self.temp_embed.shape[0]:
            logger.debug(f"Interpolating temporal encodings from {self.temp_embed.shape[0]} to {t}.")
            temp_embed = interpolate_temp_pos_embed(self.temp_embed, t)
        else:
            temp_embed = self.temp_embed
        temp_embed = temp_embed.unsqueeze(0).unsqueeze(2)
        return x + temp_embed


def create_neighbor_weight_matrix(num_tokens: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Builds a [num_tokens, num_tokens] matrix with weights 1 / 2**|i - j|."""
    indices = torch.arange(num_tokens, dtype=dtype, device=device)
    abs_diff = torch.abs(indices.unsqueeze(0) - indices.unsqueeze(1))
    weights = 1.0 / (2.0**abs_diff)
    return weights


def compute_t_adj(x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """Mixes tokens along the query axis according to the neighbor weight matrix."""
    return torch.einsum("bfnd,nk->bfkd", x, weights)


def token_propagation(x: torch.Tensor, num_tokens: int) -> torch.Tensor:
    """Apply neighboring token propagation update."""
    weights = create_neighbor_weight_matrix(num_tokens, x.device, x.dtype)
    t_adj = compute_t_adj(x, weights)
    # Forward value equals x, but gradients also flow through the propagated term.
    return x + t_adj - t_adj.detach()


class NeighboringTokenPropagationEncoding(TemporalParameterEncoding):
    """
    Neighboring Token Propagation method inspired by Momentor (https://arxiv.org/abs/2402.11435).
    """

    @add_docstring(FORWARD_DOCSTRING)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, t, q, _ = x.shape
        if t != self.temp_embed.shape[0]:
            logger.debug(f"Interpolating temporal encodings from {self.temp_embed.shape[0]} to {t}.")
            temp_embed = interpolate_temp_pos_embed(self.temp_embed, t)
        else:
            temp_embed = self.temp_embed
        temp_embed = temp_embed.unsqueeze(0).unsqueeze(2)
        if self.training:
            temp_embed = token_propagation(temp_embed, q)
        return x + temp_embed


class EncodingFactory(nn.Module):
    """Instantiates the requested temporal encoding module by name."""

    def __init__(self, encoding_type: str, embed_dim: int, max_len: int) -> None:
        super().__init__()
        fn = {
            "temporal_parameter": TemporalParameterEncoding,
            "neighboring_token_propagation": NeighboringTokenPropagationEncoding,
        }[encoding_type]
        self.encoding = fn(embed_dim=embed_dim, max_len=max_len)

    @add_docstring(FORWARD_DOCSTRING)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.encoding(x)