# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Processor class for Cosmos-Embed1."""

from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torchvision
from transformers import AutoProcessor, BatchFeature
from transformers.processing_utils import ProcessorMixin
from transformers.utils import TensorType

from .configuration_embed1 import CosmosEmbed1Config


class CosmosEmbed1Processor(ProcessorMixin):
    r"""
    Constructs a processor that wraps a BERT tokenizer and a fast video resize function.

    Args:
        tokenizer ([`BertTokenizer`] or [`BertTokenizerFast`], *optional*):
            The tokenizer is a required input for text processing.
        resolution (`int` or `Tuple[int, int]`, *optional*, defaults to 336):
            Target spatial resolution (height, width) for input videos.
        num_video_frames (`int`, *optional*, defaults to 8):
            Number of frames expected per video clip.
        max_txt_len (`int`, *optional*, defaults to 128):
            Maximum token length used when padding/truncating text.
    """

    attributes = ["tokenizer"]
    tokenizer_class = ("BertTokenizer", "BertTokenizerFast")
    config_class = CosmosEmbed1Config
    chat_template = None

    def __init__(
        self,
        tokenizer=None,
        resolution: Union[int, Tuple[int, int]] = 336,
        num_video_frames: int = 8,
        max_txt_len: int = 128,
        **kwargs,
    ) -> None:
        super().__init__(tokenizer, **kwargs)
        self.resolution = resolution
        self.num_video_frames = num_video_frames
        self.max_txt_len = max_txt_len

    def __call__(
        self,
        text: Optional[Union[str, List[str]]] = None,
        videos: Optional[Union[np.ndarray, torch.Tensor]] = None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
        resolution: Optional[Union[int, Tuple[int, int]]] = None,
        num_video_frames: Optional[int] = None,
        max_txt_len: Optional[int] = None,
        **kwargs,
    ) -> BatchFeature:
        inputs = {}

        if text is not None:
            # Tokenize to a fixed length so text batches are shape-compatible across examples.
            max_txt_len = max_txt_len if max_txt_len is not None else self.max_txt_len
            tokenized = self.tokenizer(
                text,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=max_txt_len,
                **kwargs,
            )
            inputs["input_ids"] = tokenized.input_ids
            inputs["attention_mask"] = tokenized.attention_mask.float()

        if videos is not None:
            if isinstance(videos, np.ndarray):
                videos = torch.from_numpy(videos)
            if not isinstance(videos, torch.Tensor) or videos.ndim != 5:
                raise ValueError("Processor expects a numpy or torch tensor of shape BTCHW from [0-255].")
            resolution = resolution if resolution is not None else self.resolution
            if isinstance(resolution, int):
                resolution = (resolution, resolution)
            _, t, c, h, w = videos.shape
            if c != 3:
                raise ValueError(f"Expected tensor of shape BTCHW with RGB channels, got channel size {c}.")
            num_video_frames = num_video_frames if num_video_frames is not None else self.num_video_frames
            if t != num_video_frames:
                raise ValueError(f"Expected tensor of shape BTCHW with {num_video_frames} frames, got {t}.")
            if h != resolution[0] or w != resolution[1]:
                videos = resize_video(videos, resolution)
            if videos.dtype == torch.uint8:
                videos = videos.float()
            # Scale pixel values from [0, 255] to [0, 1].
            inputs["videos"] = videos / 255.0

        if not inputs:
            raise ValueError("Must pass either `text` or `videos` argument to __call__ function.")

        return BatchFeature(inputs, tensor_type=return_tensors)
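
# Example usage (illustrative sketch, not executed at import time). Assumes a BERT-style
# tokenizer has already been loaded (e.g. with AutoTokenizer) and that raw frames come in
# as uint8 BTCHW values in [0, 255]; variable names below are hypothetical:
#
#     processor = CosmosEmbed1Processor(tokenizer, resolution=336, num_video_frames=8)
#     frames = torch.randint(0, 256, (1, 8, 3, 480, 640), dtype=torch.uint8)
#     batch = processor(text=["a person riding a bicycle"], videos=frames)
#     # batch["input_ids"], batch["attention_mask"], and batch["videos"] (resized to
#     # 336x336 and rescaled to [0, 1]) can then be passed to the Cosmos-Embed1 model.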


def resize_video(video: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
    """Resize a video tensor (B, T, C, H, W) to a new height/width.

    Args:
        video (torch.Tensor): input of shape (B, T, C, H, W), uint8 or float32.
        size (tuple): target (H', W') size.

    Returns:
        torch.Tensor: resized video of shape (B, T, C, H', W').
    """
    h, w = size
    B, T, C, H, W = video.shape
    video = video.view(B * T, C, H, W)
    resize = torchvision.transforms.Resize(
        (h, w),
        antialias=True,
        interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
    )
    video = resize(video)
    new_H, new_W = video.shape[-2:]
    video = video.view(B, T, C, new_H, new_W)
    return video


AutoProcessor.register(CosmosEmbed1Config, CosmosEmbed1Processor)

__all__ = ["CosmosEmbed1Processor"]
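
# Standalone sketch of the resize helper (illustrative only; shapes are arbitrary):
#
#     clip = torch.randint(0, 256, (2, 8, 3, 480, 640), dtype=torch.uint8)  # (B, T, C, H, W)
#     resized = resize_video(clip, (336, 336))
#     assert resized.shape == (2, 8, 3, 336, 336)
#
# The AutoProcessor.register call above maps CosmosEmbed1Config to this processor class,
# so AutoProcessor.from_pretrained on a checkpoint using that config resolves to
# CosmosEmbed1Processor.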