|
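"""NVILA-Lite model: a SigLIP vision tower bridged to a Qwen2 language model by a multimodal projector."""
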
import contextlib
import math

import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from transformers import Qwen2ForCausalLM, SiglipVisionModel
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel

from .configuration_nvila_lite import NVILALiteConfig

# Hidden size of the SigLIP vision tower's output features.
MM_HIDDEN_SIZE = 1152


class NVILALiteMultiModalProjectorDownsampleBlock(nn.Module): |
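    """Downsamples a square grid of patch tokens 9x by folding each 3x3 block into channels."""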
|
    def forward(self, x: Tensor) -> Tensor:
        batch_size, sequence_length, hidden_size = x.shape

        # The patch tokens are assumed to form a square grid.
        feat_size = math.isqrt(sequence_length)

        features = x.reshape(batch_size, feat_size, feat_size, hidden_size)

        # Zero-pad the bottom and right edges so the grid is divisible by 3.
        pad_after = (3 - feat_size % 3) % 3
        if pad_after > 0:
            features = F.pad(features, (0, 0, 0, pad_after, 0, pad_after))
            feat_size = feat_size + pad_after

        # Fold each 3x3 spatial block into the channel dimension:
        # (B, H, W, C) -> (B, H/3 * W/3, 9 * C).
        features = features.reshape(batch_size, feat_size // 3, 3, feat_size // 3, 3, hidden_size)
        features = features.permute(0, 1, 3, 2, 4, 5).contiguous()
        features = features.reshape(batch_size, -1, 9 * hidden_size)

        return features


class NVILALiteMultiModalProjector(nn.Module): |
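    """Maps vision-tower features into the LLM embedding space after 3x3 spatial downsampling."""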
|
    def __init__(self, config: NVILALiteConfig):
        super().__init__()

        self.layers = nn.Sequential(
            # Downsample 3x3: the sequence shrinks ~9x while channels grow 9x.
            NVILALiteMultiModalProjectorDownsampleBlock(),
            nn.LayerNorm(MM_HIDDEN_SIZE * 9),
            nn.Linear(MM_HIDDEN_SIZE * 9, MM_HIDDEN_SIZE * 3),
            nn.GELU(),
            nn.LayerNorm(MM_HIDDEN_SIZE * 3),
            # Project into the language model's embedding space.
            nn.Linear(MM_HIDDEN_SIZE * 3, config.text_config.hidden_size),
            nn.GELU(),
            nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.layers(x)


class NVILALiteForConditionalGeneration(PreTrainedModel, GenerationMixin): |
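    """NVILA-Lite for conditional generation: SigLIP vision tower + projector + Qwen2 causal LM."""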
|
    config_class = NVILALiteConfig
    base_model_prefix = "llm"
    _auto_class = "AutoModel"
    _supports_flash_attn = True
    _supports_sdpa = True

    def __init__(self, config: NVILALiteConfig):
        super().__init__(config)

        # Narrow the type of `self.config` for static type checkers.
        self.config: NVILALiteConfig

        @contextlib.contextmanager
        def default_torch_dtype(dtype):
            original_dtype = torch.get_default_dtype()
            torch.set_default_dtype(dtype)
            try:
                yield
            finally:
                torch.set_default_dtype(original_dtype)

        # Create the submodules directly in the configured dtype.
        with default_torch_dtype(config.torch_dtype):
            self.vision_tower = SiglipVisionModel(config.vision_config)
            self.mm_projector = NVILALiteMultiModalProjector(config)
            self.llm = Qwen2ForCausalLM(config.text_config)

        self.post_init()

    def forward(
        self,
        *,
        input_ids: Tensor | None = None,
        inputs_embeds: Tensor | None = None,
        pixel_values: Tensor | None = None,
        pixel_values_videos: Tensor | None = None,
        **kwargs,
) -> CausalLMOutputWithPast: |
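        """Run the language model, splicing vision features into the text embeddings.

        Exactly one of `input_ids` or `inputs_embeds` may be given; media placeholder
        tokens are only resolved when `input_ids` is provided.
        """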
|
        assert (input_ids is None) != (
            inputs_embeds is None
        ), "Exactly one of `input_ids` or `inputs_embeds` must be specified."

        # If the prompt contains image or video placeholder tokens, replace their
        # embeddings with projected vision features before running the LLM.
        if input_ids is not None and torch.isin(
            input_ids,
            torch.tensor(
                [self.config.image_token_id, self.config.video_token_id],
                device=input_ids.device,
            ),
        ).any():
            inputs_embeds = self._embed(
                input_ids=input_ids,
                pixel_values=pixel_values,
                pixel_values_videos=pixel_values_videos,
            )
            input_ids = None

        outputs = self.llm(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

        return outputs

    def _embed(
        self,
        *,
        input_ids: Tensor,
        pixel_values: Tensor | None,
        pixel_values_videos: Tensor | None,
) -> Tensor: |
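        """Embed `input_ids`, overwriting media placeholder positions with vision features."""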
|
        inputs_embeds: Tensor = self.llm.model.embed_tokens(input_ids)

        for media_pixel_values, media_token_id in [
            (pixel_values, self.config.image_token_id),
            (pixel_values_videos, self.config.video_token_id),
        ]:
            if media_pixel_values is None:
                continue

            vision_features = self._encode_vision(media_pixel_values)
            vision_features = einops.rearrange(vision_features, "n p d -> (n p) d")

            # Scatter the vision features into the placeholder positions; their
            # count must exactly match the number of placeholder tokens.
            inputs_embeds[input_ids == media_token_id] = vision_features

        return inputs_embeds

def _encode_vision(self, pixel_values: Tensor) -> Tensor: |
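        """Encode pixels with the vision tower and project them to the LLM hidden size."""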
|
        vision_tower_output: BaseModelOutputWithPooling = self.vision_tower(
            pixel_values,
            output_hidden_states=True,
        )
        assert vision_tower_output.hidden_states is not None

        # Take the penultimate layer's hidden states rather than the final layer's.
        vision_features = vision_tower_output.hidden_states[-2]

        vision_features = self.mm_projector(vision_features)

        return vision_features
|
|
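# Minimal usage sketch (illustrative only; assumes a checkpoint in this format and a
# processor that expands each media placeholder to one token per vision feature):
#
#     model = NVILALiteForConditionalGeneration.from_pretrained("path/to/checkpoint")
#     outputs = model(input_ids=input_ids, pixel_values=pixel_values)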