diff --git "a/modeling_instellavl.py" "b/modeling_instellavl.py" new file mode 100644--- /dev/null +++ "b/modeling_instellavl.py" @@ -0,0 +1,2463 @@ +# Modification Copyright© 2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn + +from typing import List, Optional, Tuple, Union + +from transformers import (AutoConfig, AutoModelForCausalLM, + OlmoConfig, OlmoModel, OlmoForCausalLM) +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.generation.utils import GenerateOutput +from abc import ABC, abstractmethod + +import re +import os +import math +import random +import shutil +from .mm_utils import get_anyres_image_grid_shape, rank0_print + +from .mm_utils import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN + +import torch +from einops import rearrange, repeat + +try: + from einops_exts import rearrange_many +except: + pass + + + +from torch import einsum + +from torch import Tensor, device +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import logging +from transformers.models.bert.configuration_bert import BertConfig + +logger = logging.get_logger(__name__) + +########## Projector ############## +class PoolerProjector(nn.Module): + def __init__(self, config, vision_cfg): + super().__init__() + self._config = config + self.hw = vision_cfg.image_size // vision_cfg.patch_size + + self.conv_pool = nn.Conv2d(config.mm_hidden_size, config.hidden_size, kernel_size=2, stride=2) + + self.proj = nn.Sequential( + nn.GELU(), + nn.Linear(config.hidden_size, config.hidden_size), + ) + + def forward(self, x, *args, **kwargs): + height = width = self.hw + assert height * width == x.shape[1] + x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2) + x = self.conv_pool(x) + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + + @property + def config(self): + return {"mm_projector_type": "pooler"} + +class IdentityMap(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, *args, **kwargs): + return x + + @property + def config(self): + return {"mm_projector_type": "identity"} + + +class SimpleResBlock(nn.Module): + def __init__(self, channels): + super().__init__() + self.pre_norm = nn.LayerNorm(channels) + + self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)) + + def forward(self, x): + x = self.pre_norm(x) + return x + self.proj(x) + + +def build_vision_projector(config, delay_load=False, **kwargs): + projector_type = getattr(config, "mm_projector_type", "linear") + + if projector_type == "linear": + return nn.Linear(config.mm_hidden_size, config.hidden_size) + + if projector_type == "pooler": + return PoolerProjector(config, kwargs["vision_cfg"]) + + mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type) + if mlp_gelu_match: + mlp_depth = int(mlp_gelu_match.group(1)) + modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(config.hidden_size, config.hidden_size)) + return nn.Sequential(*modules) + + mlp_gelu_resnet_match = re.match(r"^mlp(\d+)x_res(\d+)x_gelu$", projector_type) + if mlp_gelu_resnet_match: + mlp_depth = int(mlp_gelu_resnet_match.group(1)) + res_depth = int(mlp_gelu_resnet_match.group(2)) + modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(config.hidden_size, config.hidden_size)) + for _ in range(res_depth): + modules.append(SimpleResBlock(config.hidden_size)) + return nn.Sequential(*modules) + + if projector_type == "identity": + return IdentityMap() + + raise ValueError(f"Unknown projector type: {projector_type}") + +################ Resampler: Spatial Pool #################### +class SpatialPool(nn.Module): + def __init__(self, model_args, vision_tower): + super().__init__() + + self.mode = model_args.mm_spatial_pool_mode + self.stride = model_args.mm_spatial_pool_stride + self.out_channels = getattr(model_args, "mm_spatial_pool_out_channels", vision_tower.hidden_size) + + if self.mode == "average": + self.pool = nn.AvgPool2d(kernel_size=self.stride, stride=self.stride) + elif self.mode == "max": + self.pool = nn.MaxPool2d(kernel_size=self.stride, stride=self.stride) + elif self.mode == "conv": + self.pool = nn.Conv2d(in_channels=vision_tower.hidden_size, out_channels=self.out_channels, kernel_size=self.stride, stride=self.stride) + else: + raise ValueError(f"Unknown pooling mode: {self.pool}.") + + def forward(self, image_features, images, *args, **kwargs): + ori_W = int(math.sqrt(image_features.shape[1] * images.shape[3] // images.shape[2])) + ori_H = int(ori_W * images.shape[2] // images.shape[3]) + + B, _, F = image_features.shape + + image_features_spatial = image_features.view(B, ori_H, ori_H, F).permute(0, 3, 1, 2) + image_features_spatial_pool = self.pool(image_features_spatial) + + return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous() + + @property + def config(self): + return { + "mm_resampler_type": "spatial_pool", + "mm_spatial_pool_stride": self.stride, + "mm_spatial_pool_mode": self.mode, + "mm_spatial_pool_out_channels": self.out_channels, + } + + @property + def hidden_size(self): + return self.out_channels + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + +############## Qformer #################### +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone() + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if query_embeds is not None: + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError("The hidden size (%d) is not a multiple of the number of attention " "heads (%d)" % (config.hidden_size, config.num_attention_heads)) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + mixed_query_layer = self.query(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if self.config.add_cross_attention and layer_num % self.config.cross_attention_freq == 0: + self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention) + self.has_cross_attention = True + else: + self.has_cross_attention = False + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + self.intermediate_query = BertIntermediate(config) + self.output_query = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + query_length=0, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:-1] + + present_key_value = self_attention_outputs[-1] + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers" + cross_attention_outputs = self.crossattention( + query_attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + query_attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk_query, + self.chunk_size_feed_forward, + self.seq_len_dim, + query_attention_output, + ) + if attention_output.shape[1] > query_length: + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + layer_output = torch.cat([layer_output, layer_output_text], dim=1) + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_chunk_query(self, attention_output): + intermediate_output = self.intermediate_query(attention_output) + layer_output = self.output_query(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([BertLayer(config, i) for i in range(config.num_hidden_layers)]) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + query_length=0, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warn("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions, query_length) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + query_length, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=False): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: Tensor, + input_shape: Tuple[int], + device: device, + is_decoder: bool, + has_query: bool = False, + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + + # add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + if has_query: # UniLM style attention mask + causal_mask = torch.cat( + [ + torch.zeros( + (batch_size, prefix_seq_len, seq_length), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=1, + ) + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, causal_mask.shape[1], prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError("Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(input_shape, attention_mask.shape)) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # use_cache = use_cache if use_cache is not None else self.config.use_cache + + if input_ids is None: + assert query_embeds is not None, "You have to specify query_embeds when input_ids is None" + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0 + + query_length = query_embeds.shape[1] if query_embeds is not None else 0 + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeds=query_embeds, + past_key_values_length=past_key_values_length, + ) + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if is_decoder: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, + input_ids.shape, + device, + is_decoder, + has_query=(query_embeds is not None), + ) + else: + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device, is_decoder) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() + else: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + query_length=query_length, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction="mean", + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + if past_key_values is not None: + query_embeds = None + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + sequence_output = outputs[0] + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1), + ) + if reduction == "none": + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + query_mask = input_ids.new_ones(query_embeds.shape[:-1]) + attention_mask = torch.cat([query_mask, attention_mask], dim=-1) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "query_embeds": query_embeds, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past + + +class BertForMaskedLM(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=False, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class Qformer(nn.Module): + def __init__(self, model_args, vision_tower): + super().__init__() + + self.depth = model_args.mm_qformer_depth + self.num_latents = model_args.mm_qformer_latents + self.pretrained = model_args.mm_qformer_pretrained + + self.Qformer, self.query_tokens, self.ln_vision = self.build_Qformer(vision_tower.hidden_size, self.depth, self.num_latents) + + if self.pretrained is not None: + pretrained_dict = torch.load(self.pretrained, map_location="cpu")["model"] + pretrained_dict = {k: v for k, v in pretrained_dict.items() if not k.startswith("t5_proj")} + self.load_state_dict(pretrained_dict) + + def build_Qformer(self, vision_width, cross_attention_freq, num_query_token): + encoder_config = BertConfig.from_pretrained("bert-base-uncased") + encoder_config.encoder_width = vision_width + # insert cross-attention layer every other block + encoder_config.add_cross_attention = True + encoder_config.cross_attention_freq = cross_attention_freq + encoder_config.query_length = num_query_token + Qformer = BertLMHeadModel(config=encoder_config) + query_tokens = nn.Parameter(torch.zeros(1, num_query_token, encoder_config.hidden_size)) + query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) + Qformer.cls = None + Qformer.bert.embeddings.word_embeddings = None + Qformer.bert.embeddings.position_embeddings = None + for layer in Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + return Qformer, query_tokens, nn.LayerNorm(vision_width) + + def forward(self, image_features, *args, **kwargs): + x = self.ln_vision(image_features) + image_atts = torch.ones(x.size()[:-1], dtype=torch.long).to(x.device) + + query_tokens = self.query_tokens.expand(x.shape[0], -1, -1) + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=x, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + return query_output.last_hidden_state + + @property + def hidden_size(self): + return 768 + + @property + def config(self): + return { + "mm_resampler_type": "qformer", + "mm_qformer_depth": self.depth, + "mm_qformer_latents": self.num_latents, + "mm_qformer_pretrained": self.pretrained, + } + + +################### Resampler: Perciever ################### +def exists(val): + return val is not None + + +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + inner_dim = dim_head * heads + + self.norm_media = nn.LayerNorm(dim) + self.norm_latents = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, T, n1, D) + latent (torch.Tensor): latent features + shape (b, T, n2, D) + """ + x = self.norm_media(x) + latents = self.norm_latents(latents) + + h = self.heads + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h) + q = q * self.scale + + # attention + sim = einsum("... i d, ... j d -> ... i j", q, k) + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + out = einsum("... i j, ... j d -> ... i d", attn, v) + out = rearrange(out, "b h t n d -> b t n (h d)", h=h) + return self.to_out(out) + + +class PerceiverResamplerModule(nn.Module): + def __init__( + self, + *, + dim, + depth=6, + dim_head=64, + heads=8, + num_latents=64, + max_num_media=None, + max_num_frames=None, + ff_mult=4, + ): + super().__init__() + self.latents = nn.Parameter(torch.randn(num_latents, dim)) + self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim)) if exists(max_num_frames) else None + self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim)) if exists(max_num_media) else None + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult) if ff_mult > 0 else nn.Identity(), + ] + ) + ) + + self.norm = nn.LayerNorm(dim) + + def forward(self, x): + """ + Args: + x (torch.Tensor): image features + shape (b, T, F, v, D) + Returns: + shape (b, T, n, D) where n is self.num_latents + """ + b, T, F, v = x.shape[:4] + + # frame and media time embeddings + if exists(self.frame_embs): + frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v) + x = x + frame_embs + x = rearrange(x, "b T F v d -> b T (F v) d") # flatten the frame and spatial dimensions + if exists(self.media_time_embs): + x = x + self.media_time_embs[:T] + + # blocks + latents = repeat(self.latents, "n d -> b T n d", b=b, T=T) + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + return self.norm(latents) + + +class PerceiverResampler(nn.Module): + def __init__(self, model_args, vision_tower): + super().__init__() + + self.depth = model_args.mm_perceiver_depth + self.num_latents = model_args.mm_perceiver_latents + self.ff_mult = model_args.mm_perceiver_ff_mult + self.pretrained = model_args.mm_perceiver_pretrained + + self.perceiver = PerceiverResamplerModule(dim=vision_tower.hidden_size, depth=self.depth, num_latents=self.num_latents, ff_mult=self.ff_mult) + + if self.pretrained is not None: + self.load_state_dict(torch.load(self.pretrained)) + + def forward(self, image_features, *args, **kwargs): + return self.perceiver(image_features[:, None, None]).squeeze(1) + + @property + def config(self): + return { + "mm_resampler_type": "perceiver", + "mm_perceiver_depth": self.depth, + "mm_perceiver_latents": self.num_latents, + "mm_perceiver_ff_mult": self.ff_mult, + "mm_perceiver_pretrained": self.pretrained, + } + +######################### Resampler: Masker Drop ######################### +class MaskedDrop(nn.Module): + def __init__(self, model_args): + super().__init__() + + self.mode = model_args.mm_mask_drop_mode + self.skip_percentage = model_args.mm_mask_drop_skip_percentage + self.ratio = model_args.mm_mask_drop_ratio + self.ratio_upper = model_args.mm_mask_drop_ratio_upper + self.ratio_lower = model_args.mm_mask_drop_ratio_lower + + def forward(self, image_features, *args, **kwargs): + + if not self.training: + return image_features + + if self.skip_percentage > random.random(): + return image_features + + masked_features = [] + + for image_feature in image_features: + num_tokens = image_feature.shape[0] + if self.mode == "fixed": + num_keep = int(num_tokens * self.ratio) + masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0][0]) + elif self.mode == "range": + num_keep = int(num_tokens * random.uniform(self.ratio_lower, self.ratio_upper)) + masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0]) + elif self.mode == "cls_only": + masked_features.append(image_feature[0:1]) + else: + raise ValueError(f"Unexpected masked drop mode: {self.mode}") + + if self.mode not in ["range"] and (type(image_features) is not list or self.mode in ["cls_only"]): + masked_features = torch.stack(masked_features, dim=0) + + return masked_features + + @property + def config(self): + return { + "mm_resampler_type": "masked_drop", + "mm_mask_drop_mode": self.mode, + "mm_mask_drop_skip_percentage": self.skip_percentage, + "mm_mask_drop_ratio": self.ratio, + "mm_mask_drop_ratio_upper": self.ratio_upper, + "mm_mask_drop_ratio_lower": self.ratio_lower, + } + + def random_masking(self, x, len_keep): + """ + Perform per-sample random masking by per-sample shuffling. + Per-sample shuffling is done by argsort random noise. + x: [N, L, D], sequence + """ + N, L, D = x.shape # batch, length, dim + + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=x.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return x_masked, mask, ids_restore + +class IdentityMap(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, *args, **kwargs): + return x + + @property + def config(self): + return {"mm_resampler_type": None} + +###################### Resampler - Builder ###################### +def build_vision_resampler(model_args, delay_load=False, **kwargs): + resampler_type = getattr(model_args, "mm_resampler_type", None) + if resampler_type == "masked_drop": + return MaskedDrop(model_args) + elif resampler_type == "spatial_pool": + return SpatialPool(model_args, **kwargs) + elif resampler_type == "perceiver": + return PerceiverResampler(model_args, **kwargs) + elif resampler_type == "qformer": + return Qformer(model_args, **kwargs) + elif resampler_type is None: + return IdentityMap() + + raise ValueError(f"Unknown resampler type: {resampler_type}") + +from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig + +######################## Vision Tower ###################### +class CLIPVisionTower(nn.Module): + r""" + A class to represent the CLIP Vision Tower model. + + Attributes : + ------------ + - is_loaded (bool): A flag indicating whether the model is loaded. + - vision_tower_name (str): The name of the vision tower model. + - select_layer (int): The layer to select features from. + - select_feature (str): The type of feature to select. + + Methods : + ------------ + - `__init__(vision_tower: str, args: Namespace, delay_load: bool = False)`: Initializes the CLIPVisionTower with the given vision tower name and arguments. + - `load_model(device_map: Optional[dict] = None)`: Loads the vision tower model and image processor. + - `feature_select(image_forward_outs: Any) -> torch.Tensor`: Selects features from the image forward outputs based on the specified feature type. + - `forward(images: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor`: Forward pass for the vision tower model. + - `dummy_feature() -> torch.Tensor`: Returns a dummy feature tensor. + - `dtype() -> torch.dtype`: Returns the data type of the vision tower model. + - `device() -> torch.device`: Returns the device of the vision tower model. + - `config() -> Any`: Returns the configuration of the vision tower model. + - `hidden_size() -> int`: Returns the hidden size of the vision tower model. + - `num_patches_per_side() -> int`: Returns the number of patches per side of the image. + - `num_patches() -> int`: Returns the total number of patches in the image. + - `image_size() -> int`: Returns the size of the image. + """ + + def __init__(self, vision_tower, args, delay_load=False): + super().__init__() + + self.is_loaded = False + + self.vision_tower_name = vision_tower + self.select_layer = args.mm_vision_select_layer + self.select_feature = getattr(args, "mm_vision_select_feature", "patch") + + if not delay_load: + rank0_print(f"Loading vision tower: {vision_tower}") + self.load_model() + elif getattr(args, "unfreeze_mm_vision_tower", False): + # TODO: better detector is needed. + rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") + self.load_model() + elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts: + rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") + self.load_model() + else: + self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) + + def load_model(self, device_map=None): + if self.is_loaded: + rank0_print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name)) + return + + self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) + self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) + self.vision_tower.requires_grad_(False) + + self.is_loaded = True + + def feature_select(self, image_forward_outs): + select_feature_type = self.select_feature + + if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]: + select_every_k_layer = len(image_forward_outs.hidden_states) // 4 + image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1) + select_feature_type = select_feature_type.replace("slicefour_", "") + elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]: + select_layers = [-2, -5, -8, -11, 6] + image_features = torch.cat([image_forward_outs.hidden_states[i] for i in select_layers], dim=-1) + select_feature_type = select_feature_type.replace("slice_m25811_f6_", "") + else: + image_features = image_forward_outs.hidden_states[self.select_layer] + + if select_feature_type == "patch": + image_features = image_features[:, 1:] + elif select_feature_type == "cls_patch": + image_features = image_features + else: + raise ValueError(f"Unexpected select feature: {select_feature_type}") + return image_features + + def forward(self, images): + if type(images) is list: + image_features = [] + for image in images: + image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) + image_feature = self.feature_select(image_forward_out).to(image.dtype) + image_features.append(image_feature) + else: + image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) + image_features = self.feature_select(image_forward_outs).to(images.dtype) + + return image_features + + @property + def dummy_feature(self): + return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) + + @property + def dtype(self): + return self.vision_tower.dtype + + @property + def device(self): + return self.vision_tower.device + + @property + def config(self): + if self.is_loaded: + return self.vision_tower.config + else: + return self.cfg_only + + @property + def hidden_size(self): + _hidden_size = self.config.hidden_size + if "slicefour" in self.select_feature: + _hidden_size *= 4 + if "slice_m25811_f6" in self.select_feature: + _hidden_size *= 5 + return _hidden_size + + @property + def num_patches_per_side(self): + return self.config.image_size // self.config.patch_size + + @property + def num_patches(self): + _num_patches = (self.config.image_size // self.config.patch_size) ** 2 + if "cls_patch" in self.select_feature: + _num_patches += 1 + return _num_patches + + @property + def image_size(self): + return self.config.image_size + +def build_vision_tower(vision_tower_cfg, **kwargs): + vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None)) + is_absolute_path_exists = os.path.exists(vision_tower) + if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower: + return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) + + raise ValueError(f"Unknown vision tower: {vision_tower}") + +class InstellaVLMetaModel: + + def __init__(self, config): + super(InstellaVLMetaModel, self).__init__(config) + + if hasattr(config, "mm_vision_tower"): + delay_load = getattr(config, "delay_load", False) + self.vision_tower = build_vision_tower(config, delay_load=delay_load) + self.vision_resampler = build_vision_resampler(config, vision_tower=self.vision_tower) + self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config) + + if "unpad" in getattr(config, "mm_patch_merge_type", ""): + self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype)) + + def get_vision_tower(self): + vision_tower = getattr(self, "vision_tower", None) + if type(vision_tower) is list: + vision_tower = vision_tower[0] + return vision_tower + + def initialize_vision_modules(self, model_args, fsdp=None): + vision_tower = model_args.vision_tower + mm_vision_select_layer = model_args.mm_vision_select_layer + mm_vision_select_feature = model_args.mm_vision_select_feature + pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter + mm_patch_merge_type = model_args.mm_patch_merge_type + + self.config.mm_vision_tower = vision_tower + self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "") + + if self.get_vision_tower() is None: + vision_tower = build_vision_tower(model_args) + vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower) + for k, v in vision_resampler.config.items(): + setattr(self.config, k, v) + + if fsdp is not None and len(fsdp) > 0: + self.vision_tower = [vision_tower] + self.vision_resampler = [vision_resampler] + else: + self.vision_tower = vision_tower + self.vision_resampler = vision_resampler + else: + if fsdp is not None and len(fsdp) > 0: + vision_resampler = self.vision_resampler[0] + vision_tower = self.vision_tower[0] + else: + vision_resampler = self.vision_resampler + vision_tower = self.vision_tower + vision_tower.load_model() + + # In case it is frozen by LoRA + for p in self.vision_resampler.parameters(): + p.requires_grad = True + + self.config.use_mm_proj = True + self.config.mm_projector_type = getattr(model_args, "mm_projector_type", "linear") + self.config.mm_hidden_size = getattr(vision_resampler, "hidden_size", vision_tower.hidden_size) + self.config.mm_vision_select_layer = mm_vision_select_layer + self.config.mm_vision_select_feature = mm_vision_select_feature + self.config.mm_patch_merge_type = mm_patch_merge_type + self.config.online_training = model_args.online_training + + if getattr(self, "mm_projector", None) is None: + self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config) + + if "unpad" in mm_patch_merge_type: + embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype)) + self.image_newline = nn.Parameter(torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std) + else: + # In case it is frozen by LoRA + for p in self.mm_projector.parameters(): + p.requires_grad = True + + if pretrain_mm_mlp_adapter is not None: + mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu") + + def get_w(weights, keyword): + return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k} + + incompatible_keys = self.mm_projector.load_state_dict(get_w(mm_projector_weights, "mm_projector")) + rank0_print(f"Loaded mm projector weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}") + incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, "vision_resampler"), strict=False) + rank0_print(f"Loaded vision resampler weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}") + + if 'tmp-' in pretrain_mm_mlp_adapter: + pretrain_mm_mlp_adapter_folder = os.path.dirname(pretrain_mm_mlp_adapter) + shutil.rmtree(pretrain_mm_mlp_adapter_folder, ignore_errors=True) + + + +def unpad_image(tensor, original_size): + """ + Unpads a PyTorch tensor of a padded and resized image. + + Args: + tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format. + original_size (tuple): The original size of the image (height, width). + + Returns: + torch.Tensor: The unpadded image tensor. + """ + original_width, original_height = original_size + current_height, current_width = tensor.shape[1:] + + # Compute aspect ratios + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + # Determine padding size and direction + if original_aspect_ratio > current_aspect_ratio: + # Padding was added to the height + scale_factor = current_width / original_width + new_height = int(original_height * scale_factor) + padding = (current_height - new_height) // 2 + unpadded_tensor = tensor[:, padding : current_height - padding, :] + else: + # Padding was added to the width + scale_factor = current_height / original_height + new_width = int(original_width * scale_factor) + padding = (current_width - new_width) // 2 + unpadded_tensor = tensor[:, :, padding : current_width - padding] + + return unpadded_tensor + + +class InstellaVLMetaForCausalLM(ABC): + + @abstractmethod + def get_model(self): + pass + + def get_vision_tower(self): + return self.get_model().get_vision_tower() + + def get_2dPool(self, image_feature): + height = width = self.get_vision_tower().num_patches_per_side + num_frames, num_tokens, num_dim = image_feature.shape + image_feature = image_feature.view(num_frames, height, width, -1) + image_feature = image_feature.permute(0, 3, 1, 2).contiguous() + # image_feature = nn.functional.max_pool2d(image_feature, self.config.mm_spatial_pool_stride) + if self.config.mm_spatial_pool_mode == "average": + image_feature = nn.functional.avg_pool2d(image_feature, self.config.mm_spatial_pool_stride) + elif self.config.mm_spatial_pool_mode == "max": + image_feature = nn.functional.max_pool2d(image_feature, self.config.mm_spatial_pool_stride) + elif self.config.mm_spatial_pool_mode == "bilinear": + height, weight = image_feature.shape[2:] + scaled_shape = [math.ceil(height / 2), math.ceil(weight / 2)] + image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode='bilinear') + + else: + raise ValueError(f"Unexpected mm_spatial_pool_mode: {self.config.mm_spatial_pool_mode}") + image_feature = image_feature.permute(0, 2, 3, 1) + image_feature = image_feature.view(num_frames, -1, num_dim) + return image_feature + + def encode_images(self, images): + image_features = self.get_model().get_vision_tower()(images) + # image_features = self.get_model().vision_resampler(image_features, images=images) + image_features = self.get_model().mm_projector(image_features) + return image_features + + def encode_multimodals(self, videos_or_images, video_idx_in_batch, split_sizes=None): + videos_or_images_features = self.get_model().get_vision_tower()(videos_or_images) + per_videos_or_images_features = torch.split(videos_or_images_features, split_sizes, dim=0) # tuple, (dim_1, 576, 4096) + all_videos_or_images_features = [] + + for idx, feat in enumerate(per_videos_or_images_features): + feat = self.get_model().mm_projector(feat) + if idx in video_idx_in_batch: + feat = self.get_2dPool(feat) + all_videos_or_images_features.append(feat) + return all_videos_or_images_features + + def add_token_per_grid(self, image_feature): + resize_h = int(math.sqrt(image_feature.shape[1])) + num_frames = image_feature.shape[0] + image_feature = image_feature.view(num_frames, 1, resize_h, resize_h, -1) + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + return image_feature + + def add_token_per_frame(self, image_feature): + image_feature = image_feature.permute(2, 0, 1).contiguous() + image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1) + image_feature = image_feature.permute(1, 2, 0).contiguous() + return image_feature + + def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities=["image"], image_sizes=None): + vision_tower = self.get_vision_tower() + # rank_print(modalities) + if vision_tower is None or images is None or input_ids.shape[1] == 1: + return input_ids, position_ids, attention_mask, past_key_values, None, labels + + if isinstance(modalities, str): + modalities = [modalities] + + if type(images) is list or images.ndim == 5: + if type(images) is list: + images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images] + + video_idx_in_batch = [] + for _ in range(len(modalities)): + if modalities[_] == "video": + video_idx_in_batch.append(_) + + # print(video_idx_in_batch) + + images_list = [] + for image in images: + if image.ndim == 4: + images_list.append(image) + else: + images_list.append(image.unsqueeze(0)) + + # import pdb;pdb.set_trace() + concat_images = torch.cat([image for image in images_list], dim=0) + split_sizes = [image.shape[0] for image in images_list] + encoded_image_features = self.encode_images(concat_images) + # import pdb + # pdb.set_trace() + + # This is a list, each element is [num_images, patch * patch, dim] + # rank_print(f"Concat images : {concat_images.shape}") + encoded_image_features = torch.split(encoded_image_features, split_sizes) + image_features = [] + for idx, image_feat in enumerate(encoded_image_features): + if idx in video_idx_in_batch: + image_features.append(self.get_2dPool(image_feat)) + else: + image_features.append(image_feat) + # image_features = self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes) + # rank_print(f"Encoded image feats : {[x.shape for x in image_features]}") + # image_features = torch.split(image_features, split_sizes, dim=0) + mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat") + image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square") + + if mm_patch_merge_type == "flat": + image_features = [x.flatten(0, 1) for x in image_features] + + elif mm_patch_merge_type.startswith("spatial"): + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + # FIXME: now assume the image is square, and split to 2x2 patches + # num_patches = h * w, where h = w = sqrt(num_patches) + # currently image_feature is a tensor of shape (4, num_patches, hidden_size) + # we want to first unflatten it to (2, 2, h, w, hidden_size) + # rank0_print("At least we are reaching here") + if image_idx in video_idx_in_batch: # video operations + # rank0_print("Video") + if self.config.mm_newline_position == "grid": + # Grid-wise + image_feature = self.add_token_per_grid(image_feature) + + new_image_features.append(image_feature) + elif self.config.mm_newline_position == "frame": + # Frame-wise + image_feature = self.add_token_per_frame(image_feature) + + new_image_features.append(image_feature.flatten(0, 1)) + + elif self.config.mm_newline_position == "one_token": + # one-token + image_feature = image_feature.flatten(0, 1) + if 'unpad' in mm_patch_merge_type: + image_feature = torch.cat(( + image_feature, + self.model.image_newline[None].to(image_feature.device) + ), dim=0) + new_image_features.append(image_feature) + elif self.config.mm_newline_position == "no_token": + new_image_features.append(image_feature.flatten(0, 1)) + else: + raise ValueError(f"Unexpected mm_newline_position: {self.config.mm_newline_position}") + + + elif image_feature.shape[0] > 1: # multi patches and multi images operations + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + height = width = self.get_vision_tower().num_patches_per_side + + assert height * width == base_image_feature.shape[0] + + if "anyres_max" in image_aspect_ratio: + matched_anyres_max_num_patches = re.match(r"anyres_max_(\d+)", image_aspect_ratio) + if matched_anyres_max_num_patches: + max_num_patches = int(matched_anyres_max_num_patches.group(1)) + + if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio: + if hasattr(self.get_vision_tower(), "image_size"): + vision_tower_image_size = self.get_vision_tower().image_size + else: + raise ValueError("vision_tower_image_size is not found in the vision tower.") + try: + num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, vision_tower_image_size) + except Exception as e: + rank0_print(f"Error: {e}") + num_patch_width, num_patch_height = 2, 2 + image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) + else: + image_feature = image_feature.view(2, 2, height, width, -1) + + if "maxpool2x2" in mm_patch_merge_type: + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = nn.functional.max_pool2d(image_feature, 2) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + elif "unpad" in mm_patch_merge_type and "anyres_max" in image_aspect_ratio and matched_anyres_max_num_patches: + unit = image_feature.shape[2] + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + c, h, w = image_feature.shape + times = math.sqrt(h * w / (max_num_patches * unit**2)) + if times > 1.1: + image_feature = image_feature[None] + image_feature = nn.functional.interpolate(image_feature, [int(h // times), int(w // times)], mode="bilinear")[0] + image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + elif "unpad" in mm_patch_merge_type: + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + else: + image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous() + image_feature = image_feature.flatten(0, 3) + if "nobase" in mm_patch_merge_type: + pass + else: + image_feature = torch.cat((base_image_feature, image_feature), dim=0) + else: # single image operations + image_feature = image_feature[0] + if "unpad" in mm_patch_merge_type: + image_feature = torch.cat((image_feature, self.model.image_newline[None]), dim=0) + + new_image_features.append(image_feature) + image_features = new_image_features + else: + raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}") + else: + image_features = self.encode_images(images) + + # TODO: image start / end is not implemented here to support pretraining. + if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False): + raise NotImplementedError + # rank_print(f"Total images : {len(image_features)}") + + # Let's just add dummy tensors if they do not exist, + # it is a headache to deal with None all the time. + # But it is not ideal, and if you have a better idea, + # please open an issue / submit a PR, thanks. + _labels = labels + _position_ids = position_ids + _attention_mask = attention_mask + if attention_mask is None: + attention_mask = torch.ones_like(input_ids, dtype=torch.bool) + else: + attention_mask = attention_mask.bool() + if position_ids is None: + position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) + if labels is None: + labels = torch.full_like(input_ids, IGNORE_INDEX) + + # remove the padding using attention_mask -- FIXME + _input_ids = input_ids + input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)] + labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)] + + new_input_embeds = [] + new_labels = [] + cur_image_idx = 0 + # rank_print("Inserting Images embedding") + for batch_idx, cur_input_ids in enumerate(input_ids): + num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() + # rank0_print(num_images) + if num_images == 0: + try: + cur_image_features = image_features[cur_image_idx] + except IndexError: + try: + cur_image_features = image_features[cur_image_idx - 1] + except IndexError: + pass + cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids) + cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0) + new_input_embeds.append(cur_input_embeds) + new_labels.append(labels[batch_idx]) + cur_image_idx += 1 + continue + + image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]] + cur_input_ids_noim = [] + cur_labels = labels[batch_idx] + cur_labels_noim = [] + for i in range(len(image_token_indices) - 1): + cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]]) + cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]]) + split_sizes = [x.shape[0] for x in cur_labels_noim] + cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim)) + cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0) + cur_new_input_embeds = [] + cur_new_labels = [] + + for i in range(num_images + 1): + cur_new_input_embeds.append(cur_input_embeds_no_im[i]) + cur_new_labels.append(cur_labels_noim[i]) + if i < num_images: + try: + cur_image_features = image_features[cur_image_idx] + except IndexError: + cur_image_features = image_features[cur_image_idx - 1] + cur_image_idx += 1 + cur_new_input_embeds.append(cur_image_features) + cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) + + cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] + + # import pdb; pdb.set_trace() + cur_new_input_embeds = torch.cat(cur_new_input_embeds) + cur_new_labels = torch.cat(cur_new_labels) + + new_input_embeds.append(cur_new_input_embeds) + new_labels.append(cur_new_labels) + + # Truncate sequences to max length as image embeddings can make the sequence longer + tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None) + # rank_print("Finishing Inserting") + + new_input_embeds = [x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)] + new_labels = [x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)] + # TODO: Hard code for control loss spike + # if tokenizer_model_max_length is not None: + # new_input_embeds = [x[:4096] if modality != "video" else x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)] + # new_labels = [x[:4096] if modality != "video" else x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)] + + # Combine them + max_len = max(x.shape[0] for x in new_input_embeds) + batch_size = len(new_input_embeds) + + new_input_embeds_padded = [] + new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device) + attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) + position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) + # rank0_print("Prepare pos id") + + for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): + cur_len = cur_new_embed.shape[0] + if getattr(self.config, "tokenizer_padding_side", "right") == "left": + new_input_embeds_padded.append(torch.cat((torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), cur_new_embed), dim=0)) + if cur_len > 0: + new_labels_padded[i, -cur_len:] = cur_new_labels + attention_mask[i, -cur_len:] = True + position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) + else: + new_input_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)) + if cur_len > 0: + new_labels_padded[i, :cur_len] = cur_new_labels + attention_mask[i, :cur_len] = True + position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) + + new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) + # rank0_print("tokenizer padding") + + if _labels is None: + new_labels = None + else: + new_labels = new_labels_padded + + if _attention_mask is None: + attention_mask = None + else: + attention_mask = attention_mask.to(dtype=_attention_mask.dtype) + + if _position_ids is None: + position_ids = None + if getattr(self.config, "use_pos_skipping", False) and self.training: + position_ids = torch.arange(new_input_embeds.size(1), device=new_input_embeds.device).unsqueeze(0).to(new_input_embeds.device) + split_position = random.randint(0, new_input_embeds.size(1)) + left_add = random.randint(0, self.config.pos_skipping_range) + right_add = random.randint(left_add, self.config.pos_skipping_range) + position_ids[:, :split_position] += left_add + position_ids[:, split_position:] += right_add + # rank0_print("Finish preparing") + return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels + + def initialize_vision_tokenizer(self, model_args, tokenizer): + if model_args.mm_use_im_patch_token: + tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) + self.resize_token_embeddings(len(tokenizer)) + + if model_args.mm_use_im_start_end: + num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) + self.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = self.get_input_embeddings().weight.data + output_embeddings = self.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + if model_args.tune_mm_mlp_adapter: + for p in self.get_input_embeddings().parameters(): + p.requires_grad = True + for p in self.get_output_embeddings().parameters(): + p.requires_grad = False + + if model_args.pretrain_mm_mlp_adapter: + mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu") + embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"] + assert num_new_tokens == 2 + if input_embeddings.shape == embed_tokens_weight.shape: + input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:] + elif embed_tokens_weight.shape[0] == num_new_tokens: + input_embeddings[-num_new_tokens:] = embed_tokens_weight + else: + raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.") + elif model_args.mm_use_im_patch_token: + if model_args.tune_mm_mlp_adapter: + for p in self.get_input_embeddings().parameters(): + p.requires_grad = False + for p in self.get_output_embeddings().parameters(): + p.requires_grad = False + +class InstellaVLConfig(OlmoConfig): + """ + Configuration class for the InstellaVL model. + Attributes: + model_type (str): The type of the model, set to "instellavl". + """ + + model_type = "instellavl" + + +def disable_torch_init(): + r""" + Disable the redundant torch default initialization to accelerate model creation. + """ + import torch + + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + + +class InstellaVLModel(InstellaVLMetaModel, OlmoModel): + config_class = InstellaVLConfig + + def __init__(self, config: OlmoConfig): + super(InstellaVLModel, self).__init__(config) + + +class InstellaVLForCausalLM(OlmoForCausalLM, InstellaVLMetaForCausalLM): + r""" + InstellaVLForCausalLM is a class that extends OlmoForCausalLM and InstellaVLMetaForCausalLM to provide + a language model with multimodal capabilities, specifically for handling images along with text. + + 1. Attributes: + - config_class (type): The configuration class to use for this model. + - model (InstellaVLModel): The underlying model. + - lm_head (nn.Linear): The linear layer for language modeling head. + + 2. Methods: + + 1. `__init__(config: InstellaVLConfig)`: + Initializes the InstellaVLForCausalLM model with the given configuration. + + 2. `get_model() -> InstellaVLModel`: + Returns the underlying model. + + 3. `forward() -> Union[Tuple, CausalLMOutputWithPast]`: + Performs a forward pass through the model. + + 4. `generate() -> Union[GenerateOutput, torch.LongTensor]`: + Generates text based on the input. + + 5. `prepare_inputs_for_generation(input_ids: torch.LongTensor,) -> dict`: + Prepares inputs for text generation. + + """ + + config_class = InstellaVLConfig + + def __init__(self, config: OlmoConfig): + r""" + Initializes the InstellaVLForCausalLM model. + + Args: + - config (OlmoConfig): Configuration object for the model. + + Attributes: + - model (InstellaVLModel): The main model instance. + - lm_head (torch.nn.Linear): Linear layer that maps hidden states to vocabulary size. + """ + super(OlmoForCausalLM, self).__init__(config) + disable_torch_init() + config.model_type = "instellavl" + self.model = InstellaVLModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + + def get_model(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + images: Optional[torch.FloatTensor] = None, + image_sizes: Optional[List[List[int]]] = None, + return_dict: Optional[bool] = None, + modalities: Optional[List[str]] = ["image"], + cache_position=None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + - input_ids (torch.LongTensor, optional): Input token IDs. + - attention_mask (torch.Tensor, optional): Attention mask. + - position_ids (torch.LongTensor, optional): Position IDs. + - past_key_values (List[torch.FloatTensor], optional): Past key values for caching. + - inputs_embeds (torch.FloatTensor, optional): Input embeddings. + - labels (torch.LongTensor, optional): Labels for language modeling. + - use_cache (bool, optional): Whether to use cache. + - output_attentions (bool, optional): Whether to output attentions. + - output_hidden_states (bool, optional): Whether to output hidden states. + - images (torch.FloatTensor, optional): Input images. + - image_sizes (List[List[int]], optional): Sizes of input images. + - return_dict (bool, optional): Whether to return a dictionary. + - modalities (List[str], optional): List of modalities. + - cache_position (optional): Cache position. + + Returns: + Union[Tuple, CausalLMOutputWithPast]: The output of the forward pass. + """ + if inputs_embeds is None: + ( + input_ids, + position_ids, + attention_mask, + past_key_values, + inputs_embeds, + labels + ) = self.prepare_inputs_labels_for_multimodal( + input_ids, + position_ids, + attention_mask, + past_key_values, + labels, + images, + modalities, + image_sizes + ) + + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + images: Optional[torch.Tensor] = None, + image_sizes: Optional[torch.Tensor] = None, + modalities: Optional[List[str]] = ["image"], + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + r""" + Args: + - inputs (torch.Tensor, optional): Input tensor. + - images (torch.Tensor, optional): Input images. + - image_sizes (torch.Tensor, optional): Sizes of input images. + - modalities (List[str], optional): List of modalities. + - **kwargs: Additional arguments. + + Returns: + Union[GenerateOutput, torch.LongTensor]: The generated text. + """ + modalities = kwargs.pop("modalities", None) if "modalities" in kwargs and modalities is None else modalities + position_ids = kwargs.pop("position_ids", None) + attention_mask = kwargs.pop("attention_mask", None) + if "inputs_embeds" in kwargs: + raise NotImplementedError("`inputs_embeds` is not supported") + + if images is not None: + ( + inputs, + position_ids, + attention_mask, + _, + inputs_embeds, + _ + ) = self.prepare_inputs_labels_for_multimodal( + inputs, + position_ids, + attention_mask, + None, + None, + images, + image_sizes=image_sizes + ) + else: + inputs_embeds = self.get_model().embed_tokens(inputs) + return super().generate( + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + **kwargs + ) + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, + inputs_embeds=None, **kwargs): + r""" + Args: + - input_ids (torch.LongTensor): Input token IDs. + - past_key_values (List[torch.FloatTensor], optional): Past key values for caching. + - inputs_embeds (torch.FloatTensor, optional): Input embeddings. + - **kwargs: Additional arguments. + + Returns: + dict: Prepared inputs for generation. + """ + images = kwargs.pop("images", None) + image_sizes = kwargs.pop("image_sizes", None) + inputs = super().prepare_inputs_for_generation( + input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs + ) + if images is not None: + inputs['images'] = images + if image_sizes is not None: + inputs['image_sizes'] = image_sizes + return inputs + +AutoConfig.register("instellavl", InstellaVLConfig) +AutoModelForCausalLM.register(InstellaVLConfig, InstellaVLForCausalLM)