# modeling_nvila_lite.py
import contextlib
import math
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from transformers import Qwen2ForCausalLM, SiglipVisionModel
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from .configuration_nvila_lite import NVILALiteConfig
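
# Hidden size of the SigLIP vision-tower features consumed by the projector.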
MM_HIDDEN_SIZE = 1152


class NVILALiteMultiModalProjectorDownsampleBlock(nn.Module):
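    """Spatially downsample vision features by merging 3x3 patch neighborhoods.

    An input of shape (B, S, D), where S is a perfect square, is reshaped to
    its (sqrt(S), sqrt(S)) patch grid, zero-padded on the bottom/right until
    each side is divisible by 3, and regrouped so every 3x3 neighborhood
    becomes a single token of size 9 * D. With SigLIP's 27x27 grid this maps
    (B, 729, 1152) to (B, 81, 10368).
    """
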
    def forward(self, x: Tensor) -> Tensor:
        batch_size, sequence_length, hidden_size = x.shape
        # Recover the square patch grid from the flattened token sequence.
        feat_size = math.isqrt(sequence_length)
        features = x.reshape(batch_size, feat_size, feat_size, hidden_size)
        # Zero-pad the bottom/right edges so the grid side is divisible by 3.
        pad_after = (3 - feat_size % 3) % 3
        if pad_after > 0:
            features = F.pad(features, (0, 0, 0, pad_after, 0, pad_after))
            feat_size = feat_size + pad_after
        # Fold each 3x3 neighborhood into the channel dimension.
        features = features.reshape(batch_size, feat_size // 3, 3, feat_size // 3, 3, hidden_size)
        features = features.permute(0, 1, 3, 2, 4, 5).contiguous()
        features = features.reshape(batch_size, -1, 9 * hidden_size)
        return features


class NVILALiteMultiModalProjector(nn.Module):
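    """Project vision-tower features into the LLM embedding space.

    Applies the 3x3 downsample block, then LayerNorm/Linear/GELU stages that
    map 9 * MM_HIDDEN_SIZE -> 3 * MM_HIDDEN_SIZE -> the text hidden size.
    """
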
def __init__(self, config: NVILALiteConfig):
super().__init__()
self.layers = nn.Sequential(
NVILALiteMultiModalProjectorDownsampleBlock(),
nn.LayerNorm(MM_HIDDEN_SIZE * 9),
nn.Linear(MM_HIDDEN_SIZE * 9, MM_HIDDEN_SIZE * 3),
nn.GELU(),
nn.LayerNorm(MM_HIDDEN_SIZE * 3),
nn.Linear(MM_HIDDEN_SIZE * 3, config.text_config.hidden_size),
nn.GELU(),
nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.layers(x)


class NVILALiteForConditionalGeneration(PreTrainedModel, GenerationMixin):
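    """NVILA-Lite for conditional generation.

    Combines a SigLIP vision tower, a downsampling multimodal projector, and a
    Qwen2 causal language model.
    """
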
config_class = NVILALiteConfig
base_model_prefix = "llm"
_auto_class = "AutoModel"
_supports_flash_attn = True
    _supports_sdpa = True

def __init__(self, config: NVILALiteConfig):
super().__init__(config)
self.config: NVILALiteConfig
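
        # Temporarily switch torch's default dtype so the submodules below are
        # created directly in the configured precision.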
@contextlib.contextmanager
def default_torch_dtype(dtype):
original_dtype = torch.get_default_dtype()
torch.set_default_dtype(dtype)
try:
yield
finally:
torch.set_default_dtype(original_dtype)
with default_torch_dtype(config.torch_dtype):
self.vision_tower = SiglipVisionModel(config.vision_config)
self.mm_projector = NVILALiteMultiModalProjector(config)
self.llm = Qwen2ForCausalLM(config.text_config)
        self.post_init()

def forward(
self,
*,
input_ids: Tensor | None = None,
inputs_embeds: Tensor | None = None,
pixel_values: Tensor | None = None,
pixel_values_videos: Tensor | None = None,
**kwargs,
) -> CausalLMOutputWithPast:
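        """Run the language model, splicing projected vision features into the
        image/video placeholder positions when media inputs are present.
        """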
assert (input_ids is None) != (
inputs_embeds is None
), "Exactly one of `input_ids` or `inputs_embeds` must be specified."
        if input_ids is not None and torch.isin(
            input_ids,
            torch.tensor(
                [self.config.image_token_id, self.config.video_token_id],
                device=input_ids.device,
            ),
        ).any():  # Prefill: media placeholder tokens are present.
inputs_embeds = self._embed(
input_ids=input_ids,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
)
input_ids = None
outputs = self.llm(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
**kwargs,
)
        return outputs

def _embed(
self,
*,
input_ids: Tensor,
pixel_values: Tensor | None,
pixel_values_videos: Tensor | None,
) -> Tensor:
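        """Embed the text tokens, then overwrite each media placeholder
        position with the corresponding projected vision feature.
        """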
inputs_embeds: Tensor = self.llm.model.embed_tokens(input_ids)
        for media_pixel_values, media_token_id in [
            (pixel_values, self.config.image_token_id),
            (pixel_values_videos, self.config.video_token_id),
        ]:
            if media_pixel_values is None:
                continue
            vision_features = self._encode_vision(media_pixel_values)
            # Flatten (num_media, patches, dim) so the features line up with
            # the flattened placeholder positions across the batch.
            vision_features = einops.rearrange(vision_features, "n p d -> (n p) d")
            inputs_embeds[input_ids == media_token_id] = vision_features
        return inputs_embeds

def _encode_vision(self, pixel_values: Tensor) -> Tensor:
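        """Encode pixel values with the vision tower and project them into the
        LLM embedding space.
        """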
vision_tower_output: BaseModelOutputWithPooling = self.vision_tower(
pixel_values,
output_hidden_states=True,
)
assert vision_tower_output.hidden_states is not None
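        # Use the penultimate layer's hidden states, a common feature choice
        # for vision-language connectors.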
vision_features = vision_tower_output.hidden_states[-2]
vision_features = self.mm_projector(vision_features)
return vision_features
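

# Usage sketch (illustrative only, not part of the model definition; the
# checkpoint path is a placeholder and assumes the repo ships this class and a
# matching processor via remote code):
#
#   from transformers import AutoModel, AutoProcessor
#
#   model = AutoModel.from_pretrained(path, trust_remote_code=True)
#   processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
#   inputs = processor(text=..., images=..., return_tensors="pt")
#   output_ids = model.generate(**inputs)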