|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from typing import List, Optional, Tuple, Union |
|
|
|
from transformers import (AutoConfig, AutoModelForCausalLM, |
|
OlmoConfig, OlmoModel, OlmoForCausalLM) |
|
from transformers.modeling_outputs import CausalLMOutputWithPast |
|
from transformers.generation.utils import GenerateOutput |
|
from abc import ABC, abstractmethod |
|
|
|
import re |
|
import os |
|
import math |
|
import random |
|
import shutil |
|
from .mm_utils import get_anyres_image_grid_shape, rank0_print |
|
|
|
from .mm_utils import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN |
|
|
|
import torch |
|
from einops import rearrange, repeat |
|
|
|
try: |
|
from einops_exts import rearrange_many |
|
except: |
|
pass |
|
|
|
|
|
|
|
from torch import einsum |
|
|
|
from torch import Tensor, device |
|
import torch.utils.checkpoint |
|
from torch.nn import CrossEntropyLoss |
|
|
|
from transformers.activations import ACT2FN |
|
from transformers.modeling_outputs import ( |
|
BaseModelOutputWithPastAndCrossAttentions, |
|
BaseModelOutputWithPoolingAndCrossAttentions, |
|
CausalLMOutputWithCrossAttentions, |
|
MaskedLMOutput, |
|
) |
|
from transformers.modeling_utils import ( |
|
PreTrainedModel, |
|
apply_chunking_to_forward, |
|
find_pruneable_heads_and_indices, |
|
prune_linear_layer, |
|
) |
|
from transformers.utils import logging |
|
from transformers.models.bert.configuration_bert import BertConfig |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
class PoolerProjector(nn.Module):
    """Project vision-tower tokens into the LM hidden space with a strided conv.

    The flat token sequence is folded back into its square ``hw x hw`` patch
    grid, downsampled by a 2x2 / stride-2 convolution (which also maps
    ``mm_hidden_size`` -> ``hidden_size``), flattened back into tokens and
    finally passed through a GELU + Linear head.

    Args:
        config: provides ``mm_hidden_size`` and ``hidden_size``.
        vision_cfg: provides ``image_size`` and ``patch_size``; their integer
            quotient is the side length of the patch grid.
    """

    def __init__(self, config, vision_cfg):
        super().__init__()
        self._config = config
        # Side length of the square patch grid emitted by the vision tower.
        self.hw = vision_cfg.image_size // vision_cfg.patch_size

        self.conv_pool = nn.Conv2d(
            config.mm_hidden_size,
            config.hidden_size,
            kernel_size=2,
            stride=2,
        )

        self.proj = nn.Sequential(
            nn.GELU(),
            nn.Linear(config.hidden_size, config.hidden_size),
        )

    def forward(self, x, *args, **kwargs):
        """Map ``x`` of shape (batch, hw*hw, mm_hidden_size) to
        (batch, (hw // 2) ** 2, hidden_size). Extra arguments are ignored."""
        side = self.hw
        batch, num_tokens = x.shape[0], x.shape[1]
        assert side * side == num_tokens
        # (B, N, C) -> (B, C, side, side) so the conv sees a spatial grid.
        grid = x.view(batch, side, side, -1).permute(0, 3, 1, 2)
        pooled = self.conv_pool(grid)
        # (B, hidden, h', w') -> (B, h'*w', hidden)
        tokens = pooled.flatten(2).transpose(1, 2)
        return self.proj(tokens)

    @property
    def config(self):
        """Serializable description of this projector."""
        return {"mm_projector_type": "pooler"}
|
|
|
class IdentityMap(nn.Module):
    """Projector that returns its first input untouched.

    Additional positional/keyword arguments are accepted only so that this
    module can be called with the same signature as the other projectors.
    """

    def forward(self, x, *args, **kwargs):
        del args, kwargs  # accepted for interface compatibility only
        return x

    @property
    def config(self):
        """Serializable description of this projector."""
        return dict(mm_projector_type="identity")
|
|
|
|
|
class SimpleResBlock(nn.Module):
    """LayerNorm followed by a residual Linear-GELU-Linear MLP.

    Note that the residual connection adds the *normalized* input (the output
    of ``pre_norm``), not the raw input.
    """

    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)
        # Layers are constructed in the same order as before so random
        # initialization is unchanged.
        layers = [nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)]
        self.proj = nn.Sequential(*layers)

    def forward(self, x):
        normed = self.pre_norm(x)
        return normed + self.proj(normed)
|
|
|
|
|
def build_vision_projector(config, delay_load=False, **kwargs):
    """Create the module that maps vision features into the LM hidden space.

    ``config.mm_projector_type`` (default ``"linear"``) selects the projector:

    * ``"linear"``: a single ``nn.Linear(mm_hidden_size, hidden_size)``.
    * ``"pooler"``: :class:`PoolerProjector`; requires ``vision_cfg`` in kwargs.
    * ``"mlp{N}x_gelu"``: ``N`` linear layers separated by GELU.
    * ``"mlp{N}x_res{M}x_gelu"``: the same MLP followed by ``M`` :class:`SimpleResBlock`.
    * ``"identity"``: :class:`IdentityMap`.

    ``delay_load`` is accepted for interface compatibility and not used here.

    Raises:
        ValueError: for any other projector type.
    """
    projector_type = getattr(config, "mm_projector_type", "linear")

    if projector_type == "linear":
        return nn.Linear(config.mm_hidden_size, config.hidden_size)
    if projector_type == "pooler":
        return PoolerProjector(config, kwargs["vision_cfg"])

    def _gelu_mlp_layers(depth):
        # One width-changing Linear, then (depth - 1) x [GELU, square Linear].
        layers = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(depth - 1):
            layers.append(nn.GELU())
            layers.append(nn.Linear(config.hidden_size, config.hidden_size))
        return layers

    plain = re.match(r"^mlp(\d+)x_gelu$", projector_type)
    if plain:
        return nn.Sequential(*_gelu_mlp_layers(int(plain.group(1))))

    residual = re.match(r"^mlp(\d+)x_res(\d+)x_gelu$", projector_type)
    if residual:
        layers = _gelu_mlp_layers(int(residual.group(1)))
        layers.extend(SimpleResBlock(config.hidden_size) for _ in range(int(residual.group(2))))
        return nn.Sequential(*layers)

    if projector_type == "identity":
        return IdentityMap()

    raise ValueError(f"Unknown projector type: {projector_type}")
|
|
|
|
|
class SpatialPool(nn.Module):
    """Spatially downsample vision-tower patch tokens.

    Tokens are laid back out on their 2-D patch grid (whose aspect ratio is
    recovered from the spatial size of ``images``), pooled with a
    ``stride`` x ``stride`` window, and flattened back into a token sequence.

    Args:
        model_args: provides ``mm_spatial_pool_mode`` (``"average"``, ``"max"``
            or ``"conv"``), ``mm_spatial_pool_stride`` and optionally
            ``mm_spatial_pool_out_channels`` (defaults to
            ``vision_tower.hidden_size``; only changes the width in ``"conv"`` mode).
        vision_tower: provides ``hidden_size``, the width of incoming tokens.

    Raises:
        ValueError: if ``mm_spatial_pool_mode`` is not a supported mode.
    """

    def __init__(self, model_args, vision_tower):
        super().__init__()

        self.mode = model_args.mm_spatial_pool_mode
        self.stride = model_args.mm_spatial_pool_stride
        self.out_channels = getattr(model_args, "mm_spatial_pool_out_channels", vision_tower.hidden_size)

        if self.mode == "average":
            self.pool = nn.AvgPool2d(kernel_size=self.stride, stride=self.stride)
        elif self.mode == "max":
            self.pool = nn.MaxPool2d(kernel_size=self.stride, stride=self.stride)
        elif self.mode == "conv":
            self.pool = nn.Conv2d(in_channels=vision_tower.hidden_size, out_channels=self.out_channels, kernel_size=self.stride, stride=self.stride)
        else:
            # Report the offending *mode*: ``self.pool`` has not been assigned on
            # this path, so referencing it would raise AttributeError instead.
            raise ValueError(f"Unknown pooling mode: {self.mode}.")

    def forward(self, image_features, images, *args, **kwargs):
        """Pool ``image_features`` of shape (batch, H*W, channels).

        ``images`` is only used for its sizes along dims 2 and 3 to recover the
        grid's aspect ratio (assumes NCHW layout, i.e. dims 2/3 are
        height/width — TODO confirm against callers).

        Returns:
            Tensor of shape (batch, pooled_tokens, channels_out).
        """
        ori_W = int(math.sqrt(image_features.shape[1] * images.shape[3] // images.shape[2]))
        ori_H = int(ori_W * images.shape[2] // images.shape[3])

        batch, _, channels = image_features.shape

        # Grid is ori_H rows by ori_W columns. (Using ori_H for both broke every
        # non-square grid.)
        image_features_spatial = image_features.view(batch, ori_H, ori_W, channels).permute(0, 3, 1, 2)
        image_features_spatial_pool = self.pool(image_features_spatial)

        return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous()

    @property
    def config(self):
        """Serializable description of this resampler."""
        return {
            "mm_resampler_type": "spatial_pool",
            "mm_spatial_pool_stride": self.stride,
            "mm_spatial_pool_mode": self.mode,
            "mm_spatial_pool_out_channels": self.out_channels,
        }

    @property
    def hidden_size(self):
        """Width of the pooled tokens as reported to downstream modules."""
        return self.out_channels
|
|
|
def disabled_train(self, mode=True):
    """No-op stand-in for ``nn.Module.train``.

    Assigning this over a module's ``train`` method freezes its current
    train/eval state: ``mode`` is ignored and the object is returned as-is.
    """
    del mode  # intentionally ignored
    return self
|
|
|
|
|
class BertEmbeddings(nn.Module):
    """Word + position embeddings, optionally prefixed with query embeddings.

    When ``input_ids`` is given, its word embeddings (plus absolute position
    embeddings if configured) are concatenated *after* ``query_embeds``.
    Without ``input_ids`` only ``query_embeds`` is used. The result is
    LayerNorm-ed and passed through dropout.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Buffer of shape (1, max_position_embeddings) holding 0..max-1; sliced
        # in forward to build default position ids.
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

        self.config = config

    def forward(
        self,
        input_ids=None,
        position_ids=None,
        query_embeds=None,
        past_key_values_length=0,
    ):
        if input_ids is None:
            # Query-only path (query_embeds must then be provided).
            embeddings = query_embeds
        else:
            text_len = input_ids.size()[1]
            if position_ids is None:
                # Text positions continue after any cached (past) positions.
                start = past_key_values_length
                position_ids = self.position_ids[:, start : start + text_len].clone()

            embeddings = self.word_embeddings(input_ids)
            if self.position_embedding_type == "absolute":
                embeddings = embeddings + self.position_embeddings(position_ids)

            if query_embeds is not None:
                embeddings = torch.cat((query_embeds, embeddings), dim=1)

        return self.dropout(self.LayerNorm(embeddings))
|
|
|
|
|
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention, used for self- and cross-attention.

    Args:
        config: reads ``hidden_size``, ``num_attention_heads``,
            ``attention_probs_dropout_prob``, optionally ``position_embedding_type``
            (and ``max_position_embeddings`` for relative variants), and
            ``encoder_width`` when ``is_cross_attention`` is True.
        is_cross_attention: if True, the key/value projections take
            ``config.encoder_width``-dimensional inputs instead of ``hidden_size``.
    """

    def __init__(self, config, is_cross_attention):
        super().__init__()
        self.config = config
        # Divisibility is only enforced when the config has no ``embedding_size``
        # attribute; otherwise the head size below is truncated by int().
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError("The hidden size (%d) is not a multiple of the number of attention " "heads (%d)" % (config.hidden_size, config.num_attention_heads))

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            # Keys/values are projected from the encoder states, whose width may differ.
            self.key = nn.Linear(config.encoder_width, self.all_head_size)
            self.value = nn.Linear(config.encoder_width, self.all_head_size)
        else:
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            # One embedding per signed distance in [-(max-1), max-1].
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
        # When set to True externally, cross-attention probabilities and their
        # gradients are stored on this module during forward/backward.
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        # Used as a tensor backward hook on the cross-attention probabilities.
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def transpose_for_scores(self, x):
        """Split the last dim into heads: (B, L, all_head_size) -> (B, heads, L, head_size)."""
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Attend from ``hidden_states`` to itself, or to ``encoder_hidden_states``.

        Args:
            hidden_states: queries (and keys/values for self-attention).
            attention_mask: mask *added* to the raw scores (so it must already be
                additive, e.g. 0 to keep / large negative to mask). Ignored and
                replaced by ``encoder_attention_mask`` in cross-attention.
            head_mask: multiplied into the (dropped-out) attention probabilities.
            encoder_hidden_states: if given, this call is cross-attention.
            past_key_value: optional (key, value) cache for self-attention; new
                keys/values are appended along the sequence dim.
            output_attentions: also return the attention probabilities.

        Returns:
            ``(context, [attention_probs], (key_layer, value_layer))``. For
            cross-attention the returned "cache" holds the encoder keys/values.
        """
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            # Prepend cached keys/values (dim 2 is the sequence dim after transpose).
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        mixed_query_layer = self.query(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)

        past_key_value = (key_layer, value_layer)

        # Raw scores: (B, heads, L_q, L_k).
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            # NOTE(review): distances are built from the query length only, so the
            # einsums below assume L_k == L_q (no cache, no cross-attention).
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            # Shift signed distances to non-negative embedding indices.
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Additive mask, broadcast over heads / query positions.
            attention_scores = attention_scores + attention_mask

        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # Dropout on probabilities: drops whole tokens to attend to. The
        # un-dropped probabilities are what is returned when output_attentions.
        attention_probs_dropped = self.dropout(attention_probs)

        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        # Merge heads back: (B, heads, L, head_size) -> (B, L, all_head_size).
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        # The key/value pair is always appended as the last element.
        outputs = outputs + (past_key_value,)
        return outputs
|
|
|
|
|
class BertSelfOutput(nn.Module):
    """Output sub-layer after attention: Linear -> dropout -> residual add -> LayerNorm."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project ``hidden_states`` and add the residual ``input_tensor`` before normalizing."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
|
|
|
|
|
class BertAttention(nn.Module):
    """Attention block: :class:`BertSelfAttention` followed by :class:`BertSelfOutput`.

    Args:
        config: model config forwarded to both sub-modules.
        is_cross_attention: build the attention for cross-attention
            (keys/values from ``config.encoder_width``-wide encoder states).
    """

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.self = BertSelfAttention(config, is_cross_attention)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from this block.

        The Q/K/V projections lose the pruned heads' output rows and the output
        projection loses the matching input columns; head counts are updated.
        """
        if len(heads) == 0:
            return

        attn = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads,
            attn.num_attention_heads,
            attn.attention_head_size,
            self.pruned_heads,
        )

        attn.query = prune_linear_layer(attn.query, index)
        attn.key = prune_linear_layer(attn.key, index)
        attn.value = prune_linear_layer(attn.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep the head bookkeeping consistent with the shrunken projections.
        attn.num_attention_heads = attn.num_attention_heads - len(heads)
        attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Return ``(attention_output, *extras)`` where ``extras`` are the
        trailing outputs of :class:`BertSelfAttention` (optional attention
        probabilities and the key/value pair)."""
        context, *extras = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        # Residual connection uses the block's own input.
        return (self.output(context, hidden_states), *extras)
|
|
|
|
|
class BertIntermediate(nn.Module):
    """Feed-forward expansion: Linear(hidden -> intermediate) followed by the activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        act = config.hidden_act
        # ``hidden_act`` is either a name looked up in ACT2FN or a callable.
        self.intermediate_act_fn = ACT2FN[act] if isinstance(act, str) else act

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))
|
|
|
|
|
class BertOutput(nn.Module):
    """Feed-forward contraction: Linear(intermediate -> hidden), dropout, residual add, LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Contract ``hidden_states`` and add the residual ``input_tensor`` before normalizing."""
        contracted = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(contracted + input_tensor)
|
|
|
|
|
class BertLayer(nn.Module):
    """One transformer layer with separate handling of query and text tokens.

    All tokens share one self-attention. The first ``query_length`` positions
    (query tokens) then optionally cross-attend to encoder states and use their
    own feed-forward weights (``intermediate_query`` / ``output_query``), while
    the remaining (text) positions use ``intermediate`` / ``output``.

    Args:
        config: reads ``chunk_size_feed_forward``, ``add_cross_attention`` and
            ``cross_attention_freq``.
        layer_num: index of this layer; cross-attention is built only when
            ``config.add_cross_attention`` and ``layer_num % cross_attention_freq == 0``.
    """

    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.layer_num = layer_num
        if self.config.add_cross_attention and layer_num % self.config.cross_attention_freq == 0:
            self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
            self.has_cross_attention = True
        else:
            self.has_cross_attention = False
        # Feed-forward weights for text tokens.
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

        # Separate feed-forward weights for query tokens.
        self.intermediate_query = BertIntermediate(config)
        self.output_query = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        """Run self-attention, optional query cross-attention, and feed-forward.

        Returns:
            ``(layer_output, [self_attn_probs], [cross_attn_probs], present_key_value)``;
            the bracketed entries appear only when ``output_attentions`` is True
            (cross-attention probs additionally require a cross-attention layer
            and ``query_length > 0``). ``present_key_value`` is the
            self-attention (key, value) pair.
        """
        # Only the first two cached tensors (self-attention key/value) are used.
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]
        # Everything between the output and the trailing cache: optional attention probs.
        outputs = self_attention_outputs[1:-1]

        present_key_value = self_attention_outputs[-1]

        if query_length > 0:
            query_attention_output = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
                # ``attention_mask`` is passed positionally, but BertSelfAttention
                # replaces it with ``encoder_attention_mask`` in cross-attention.
                cross_attention_outputs = self.crossattention(
                    query_attention_output,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                query_attention_output = cross_attention_outputs[0]
                outputs = outputs + cross_attention_outputs[1:-1]

            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )
            # Text tokens (if any) follow the queries and use the text FFN.
            if attention_output.shape[1] > query_length:
                layer_output_text = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, layer_output_text], dim=1)
        else:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )
        outputs = (layer_output,) + outputs

        outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        """Text-token feed-forward: intermediate expansion then residual output."""
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

    def feed_forward_chunk_query(self, attention_output):
        """Query-token feed-forward using the query-specific weights."""
        intermediate_output = self.intermediate_query(attention_output)
        layer_output = self.output_query(intermediate_output, attention_output)
        return layer_output
|
|
|
|
|
class BertEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` :class:`BertLayer` modules.

    Runs the hidden states through every layer and optionally collects the
    per-layer hidden states, attention probabilities and key/value caches.
    Gradient checkpointing is used in training when
    ``config.gradient_checkpointing`` is truthy.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([BertLayer(config, i) for i in range(config.num_hidden_layers)])

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
    ):
        """Apply all layers.

        Args:
            hidden_states: input embeddings (queries first, then text).
            attention_mask: additive self-attention mask passed to every layer.
            head_mask: per-layer head masks (indexed by layer) or None.
            encoder_hidden_states / encoder_attention_mask: cross-attention inputs.
            past_key_values: per-layer key/value caches or None.
            use_cache: collect each layer's key/value pair. Forced off under
                gradient checkpointing.
            output_attentions / output_hidden_states: collect extras.
            return_dict: return a ``BaseModelOutputWithPastAndCrossAttentions``
                instead of a tuple of the non-None values.
            query_length: number of leading query tokens.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if getattr(self.config, "gradient_checkpointing", False) and self.training:
                if use_cache:
                    logger.warning("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
                    use_cache = False

                # The per-layer extras are bound as arguments so they are captured
                # by value. The checkpointed function is re-run during backward,
                # after this loop has advanced; a plain closure over the loop
                # variable ``past_key_value`` would then see the last layer's value.
                def create_custom_forward(module, layer_past, attn_flag, q_len):
                    def custom_forward(*inputs):
                        return module(*inputs, layer_past, attn_flag, q_len)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module, past_key_value, output_attentions, query_length),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    query_length,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                # The layer's key/value pair is always its last output.
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                # layer_outputs is (hidden, self_attn, [cross_attn,] present_kv).
                # The cross-attention entry only exists for layers that actually
                # ran cross-attention. Otherwise index 2 is the key/value cache.
                # all_cross_attentions is None when the config has no
                # cross-attention at all.
                if all_cross_attentions is not None and len(layer_outputs) > 3:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
|
|
|
|
|
class BertPooler(nn.Module):
    """Pool a sequence by transforming the hidden state of its first token (Linear + tanh)."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        first_token = hidden_states[:, 0]
        return self.activation(self.dense(first_token))
|
|
|
|
|
class BertPredictionHeadTransform(nn.Module):
    """Pre-decoder transform for the LM head: Linear -> activation -> LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        act = config.hidden_act
        # ``hidden_act`` is either a name looked up in ACT2FN or a callable.
        self.transform_act_fn = ACT2FN[act] if isinstance(act, str) else act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        activated = self.transform_act_fn(self.dense(hidden_states))
        return self.LayerNorm(activated)
|
|
|
|
|
class BertLMPredictionHead(nn.Module):
    """LM head: :class:`BertPredictionHeadTransform` then a vocab-sized decoder.

    The decoder Linear is created without a bias and then given a standalone
    zero-initialized ``bias`` parameter, which is the same object as
    ``self.bias`` (presumably so the bias follows the decoder if output
    embeddings are resized — TODO confirm).
    """

    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        shared_bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.bias = shared_bias
        self.decoder.bias = shared_bias

    def forward(self, hidden_states):
        """Return vocabulary logits for each position of ``hidden_states``."""
        return self.decoder(self.transform(hidden_states))
|
|
|
|
|
class BertOnlyMLMHead(nn.Module):
    """Thin wrapper exposing :class:`BertLMPredictionHead` as ``self.predictions``."""

    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        """Return vocabulary logits for ``sequence_output``."""
        return self.predictions(sequence_output)
|
|
|
|
|
class BertPreTrainedModel(PreTrainedModel):
    """
    Abstract base that wires the BERT config class and weight initialization into
    ``transformers.PreTrainedModel`` (downloading / loading pretrained weights).
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialize ``module`` in place.

        Linear and Embedding weights ~ N(0, ``config.initializer_range``);
        LayerNorm gets bias 0 and weight 1; Linear biases (when present) are zeroed.
        """
        is_linear = isinstance(module, nn.Linear)
        if is_linear or isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

        if is_linear and module.bias is not None:
            module.bias.data.zero_()
|
|
|
|
|
class BertModel(BertPreTrainedModel):
    """
    BERT-style transformer that can run as a bidirectional encoder or, with ``is_decoder=True`` passed to
    :meth:`forward`, with a causal self-attention mask over the text tokens, following the architecture described in
    `Attention is all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar,
    Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    Cross-attention sub-layers are built when the config has :obj:`add_cross_attention` set to :obj:`True` (in every
    ``cross_attention_freq``-th layer); an :obj:`encoder_hidden_states` is then expected as an input to the forward
    pass. Optional ``query_embeds`` are prepended to the text embeddings; only these query positions cross-attend to
    the encoder states, and they use query-specific feed-forward weights in each layer.
    """

    def __init__(self, config, add_pooling_layer=False):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)

        self.encoder = BertEncoder(config)

        # Pooling is off by default; pooler_output is then None.
        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def get_extended_attention_mask(
        self,
        attention_mask: Tensor,
        input_shape: Tuple[int],
        device: device,
        is_decoder: bool,
        has_query: bool = False,
    ) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore. Either 3-D
                (batch, from_seq, to_seq) or 2-D (batch, seq).
            input_shape (:obj:`Tuple[int]`):
                (batch_size, seq_length) used to build the causal mask; in decoder mode this is the text length,
                which may be shorter than ``attention_mask`` (the difference is treated as a prefix).
            device: (:obj:`torch.device`):
                The device of the input to the model.
            is_decoder (:obj:`bool`):
                Build a causal (lower-triangular) mask for 2-D ``attention_mask``.
            has_query (:obj:`bool`):
                In decoder mode, the prefix holds query tokens: query rows may not attend to text positions, and
                every row may attend to all prefix positions.

        Returns:
            :obj:`torch.Tensor` The additive extended attention mask (batch, 1, from_seq, to_seq) cast to
            ``self.dtype``: 0.0 where attention is allowed and -10000.0 where it is masked.

        Raises:
            ValueError: if ``attention_mask`` is neither 2-D nor 3-D.
        """
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            if is_decoder:
                batch_size, seq_length = input_shape

                seq_ids = torch.arange(seq_length, device=device)
                # causal_mask[b, i, j] is True iff j <= i.
                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]

                causal_mask = causal_mask.to(attention_mask.dtype)

                if causal_mask.shape[1] < attention_mask.shape[1]:
                    # attention_mask also covers a prefix (e.g. query tokens) before the text.
                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
                    if has_query:
                        # Add prefix rows that cannot see any text position.
                        causal_mask = torch.cat(
                            [
                                torch.zeros(
                                    (batch_size, prefix_seq_len, seq_length),
                                    device=device,
                                    dtype=causal_mask.dtype,
                                ),
                                causal_mask,
                            ],
                            axis=1,
                        )
                    # Every row may attend to all prefix positions.
                    causal_mask = torch.cat(
                        [
                            torch.ones(
                                (batch_size, causal_mask.shape[1], prefix_seq_len),
                                device=device,
                                dtype=causal_mask.dtype,
                            ),
                            causal_mask,
                        ],
                        axis=-1,
                    )
                # Combine the causal structure with the padding mask.
                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError("Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(input_shape, attention_mask.shape))

        # Convert the 1/0 mask into an additive mask in the model's dtype.
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        is_decoder=False,
    ):
        r"""
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
            layers (only by the query positions). May also be a list of tensors; only the first is used to size the
            default mask.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention layers. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers`, each tuple holding the self-attention key and value tensors of shape :obj:`(batch_size, num_heads, past_length, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the self-attention blocks. Can be used to speed up
            decoding. If :obj:`past_key_values` are used, the user can optionally input only the last
            :obj:`input_ids` (those that don't have their past key value states given to this model) of shape
            :obj:`(batch_size, 1)` instead of all :obj:`input_ids` of shape :obj:`(batch_size, sequence_length)`.
            NOTE(review): ``config.query_length`` must be set when past_key_values are given.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        is_decoder (:obj:`bool`):
            Use a causal mask over the text tokens. Requires ``input_ids`` (its shape sizes the causal mask).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None:
            assert query_embeds is not None, "You have to specify query_embeds when input_ids is None"

        # Cached length excluding the query tokens (the cache also holds them).
        past_key_values_length = past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0

        query_length = query_embeds.shape[1] if query_embeds is not None else 0

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            query_embeds=query_embeds,
            past_key_values_length=past_key_values_length,
        )

        input_shape = embedding_output.size()[:-1]
        batch_size, seq_length = input_shape
        device = embedding_output.device

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        # Decoder mode sizes the causal mask from the text ids only; the query
        # positions are handled as a prefix inside get_extended_attention_mask.
        if is_decoder:
            extended_attention_mask = self.get_extended_attention_mask(
                attention_mask,
                input_ids.shape,
                device,
                is_decoder,
                has_query=(query_embeds is not None),
            )
        else:
            extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device, is_decoder)

        # Build additive masks for the encoder (cross-attention) states.
        if encoder_hidden_states is not None:
            if type(encoder_hidden_states) == list:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
            else:
                (
                    encoder_batch_size,
                    encoder_sequence_length,
                    _,
                ) = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

            if type(encoder_attention_mask) == list:
                encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
            elif encoder_attention_mask is None:
                # Default: attend to every encoder position.
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Per-layer head mask as prepared by PreTrainedModel.get_head_mask.
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            query_length=query_length,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
|
|
|
|
|
class BertLMHeadModel(BertPreTrainedModel):
    """BERT encoder (no pooler) plus a prediction head producing vocabulary logits.

    ``forward`` computes a next-token (shifted) loss, so the model acts as a
    left-to-right decoder. In this file it is instantiated by ``Qformer``,
    which only uses ``self.bert``.
    """

    # Checkpoint keys that may legitimately be unexpected / missing when loading.
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=True,
        reduction="mean",
    ):
        r"""
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
            ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        return_logits (:obj:`bool`):
            If True, return only the logits with the last position dropped (aligned with the shifted labels).
        reduction (:obj:`str`):
            Passed to ``CrossEntropyLoss``; with ``"none"`` the per-token loss is summed per sample.

        Returns:

        Example::

            >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
            >>> import torch
            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
            >>> config = BertConfig.from_pretrained("bert-base-cased")
            >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
            >>> outputs = model(**inputs)
            >>> prediction_logits = outputs.logits
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # No cache is needed when computing a training loss.
        if labels is not None:
            use_cache = False
        # With a cache, the query tokens were already consumed on an earlier step
        # (presumably stored in the cache) — they are not fed again.
        if past_key_values is not None:
            query_embeds = None

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        sequence_output = outputs[0]
        if query_embeds is not None:
            # Drop the query-token positions so only text positions are scored.
            sequence_output = outputs[0][:, query_embeds.shape[1] :, :]

        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores[:, :-1, :].contiguous()

        lm_loss = None
        if labels is not None:
            # Next-token prediction: position t predicts label t + 1.
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            # Label smoothing of 0.1 is hard-coded here.
            loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
            lm_loss = loss_fct(
                shifted_prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1),
            )
            if reduction == "none":
                # Sum the per-token losses into one value per sample.
                lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs):
        """Build the ``forward`` kwargs for one generation step.

        Prepends an all-ones mask for the query tokens to ``attention_mask``. When a
        cache is present, only the last input token is kept.
        """
        # Newer ``transformers`` releases hand the cache over as ``past_key_values``
        # instead of ``past``; accept either so the cache is not silently dropped.
        if past is None:
            past = model_kwargs.get("past_key_values")

        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)
        query_mask = input_ids.new_ones(query_embeds.shape[:-1])
        attention_mask = torch.cat([query_mask, attention_mask], dim=-1)

        # With cached keys/values only the newest token needs to be processed.
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "query_embeds": query_embeds,
            "attention_mask": attention_mask,
            "past_key_values": past,
            "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
            "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
            "is_decoder": True,
        }

    def _reorder_cache(self, past, beam_idx):
        """Reorder every cached tensor along dim 0 according to ``beam_idx``."""
        reordered_past = ()
        for layer_past in past:
            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
        return reordered_past
|
|
|
|
|
class BertForMaskedLM(BertPreTrainedModel):
    """BERT encoder (no pooler) with a masked-language-modeling head."""

    # Checkpoint keys that may legitimately be unexpected / missing when loading.
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=False,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
        return_logits (:obj:`bool`):
            If True, return only the prediction scores.
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        # Always start from the full encoder output; previously this name was only
        # bound when ``query_embeds`` was given, raising NameError otherwise.
        sequence_output = outputs[0]
        if query_embeds is not None:
            # Drop the query-token positions so only text positions are scored.
            sequence_output = sequence_output[:, query_embeds.shape[1] :, :]
        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
|
class Qformer(nn.Module):
    """Vision resampler in which learned query tokens cross-attend to image features.

    Reads from ``model_args``:
      * ``mm_qformer_depth``: forwarded to ``build_Qformer`` as ``cross_attention_freq``.
      * ``mm_qformer_latents``: number of learned query tokens.
      * ``mm_qformer_pretrained``: optional checkpoint path. Its ``"model"`` entry is
        loaded (strictly) after discarding keys that start with ``t5_proj``.
    """

    def __init__(self, model_args, vision_tower):
        super().__init__()

        self.depth = model_args.mm_qformer_depth
        self.num_latents = model_args.mm_qformer_latents
        self.pretrained = model_args.mm_qformer_pretrained

        built = self.build_Qformer(vision_tower.hidden_size, self.depth, self.num_latents)
        self.Qformer, self.query_tokens, self.ln_vision = built

        if self.pretrained is None:
            return

        checkpoint = torch.load(self.pretrained, map_location="cpu")["model"]
        # This module has no ``t5_proj`` submodule, so those weights are discarded.
        kept = {}
        for name, tensor in checkpoint.items():
            if not name.startswith("t5_proj"):
                kept[name] = tensor
        self.load_state_dict(kept)

    def build_Qformer(self, vision_width, cross_attention_freq, num_query_token):
        """Create the trimmed BERT Q-Former, its query tokens and the vision LayerNorm.

        Returns:
            ``(qformer, query_tokens, ln_vision)`` where ``query_tokens`` has shape
            ``(1, num_query_token, hidden_size)`` and ``ln_vision`` normalises ``vision_width`` features.
        """
        cfg = BertConfig.from_pretrained("bert-base-uncased")
        cfg.encoder_width = vision_width
        # Enable cross-attention so the queries can attend to the image features.
        cfg.add_cross_attention = True
        cfg.cross_attention_freq = cross_attention_freq
        cfg.query_length = num_query_token

        qformer = BertLMHeadModel(config=cfg)

        queries = nn.Parameter(torch.zeros(1, num_query_token, cfg.hidden_size))
        queries.data.normal_(mean=0.0, std=cfg.initializer_range)

        # Strip the parts not needed here: the LM head, the word/position embeddings,
        # and each layer's ``output``/``intermediate`` modules (presumably the text-path
        # FFN; the query path is assumed to use separate modules — confirm against BertLayer).
        qformer.cls = None
        embeddings = qformer.bert.embeddings
        embeddings.word_embeddings = None
        embeddings.position_embeddings = None
        for block in qformer.bert.encoder.layer:
            block.output = None
            block.intermediate = None

        return qformer, queries, nn.LayerNorm(vision_width)

    def forward(self, image_features, *args, **kwargs):
        """Return the Q-Former's last hidden state for the learned queries.

        ``image_features`` is assumed to be ``(batch, num_tokens, vision_width)`` — TODO confirm.
        """
        normed = self.ln_vision(image_features)
        # Every image token is attendable.
        image_mask = torch.ones(normed.size()[:-1], dtype=torch.long, device=normed.device)

        batch_queries = self.query_tokens.expand(normed.shape[0], -1, -1)
        result = self.Qformer.bert(
            query_embeds=batch_queries,
            encoder_hidden_states=normed,
            encoder_attention_mask=image_mask,
            return_dict=True,
        )
        return result.last_hidden_state

    @property
    def hidden_size(self):
        # Hard-coded; expected to equal the ``bert-base-uncased`` hidden size used by build_Qformer.
        return 768

    @property
    def config(self):
        return {
            "mm_resampler_type": "qformer",
            "mm_qformer_depth": self.depth,
            "mm_qformer_latents": self.num_latents,
            "mm_qformer_pretrained": self.pretrained,
        }
|
|
|
|
|
|
|
def exists(val):
    """Return True when *val* is anything other than ``None`` (falsy values count as existing)."""
    return not (val is None)
|
|
|
|
|
def FeedForward(dim, mult=4):
    """Build a bias-free pre-norm MLP.

    Layout: ``LayerNorm(dim) -> Linear(dim, int(dim * mult)) -> GELU -> Linear(int(dim * mult), dim)``.
    """
    hidden = int(dim * mult)
    layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden, bias=False),
        nn.GELU(),
        nn.Linear(hidden, dim, bias=False),
    ]
    return nn.Sequential(*layers)
|
|
|
|
|
class PerceiverAttention(nn.Module):
    """Multi-head Perceiver cross-attention.

    Queries come from the latents. Keys and values come from the concatenation
    of the (normalised) media tokens and latents along the token axis.
    """

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    @staticmethod
    def _split_heads(t, heads):
        """Reshape ``(b, t, n, heads * d)`` to ``(b, heads, t, n, d)`` (heads is the outer factor)."""
        return t.reshape(*t.shape[:-1], heads, -1).permute(0, 3, 1, 2, 4)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, T, n1, D)
            latents (torch.Tensor): latent features
                shape (b, T, n2, D)

        Returns:
            torch.Tensor of shape (b, T, n2, D).
        """
        x = self.norm_media(x)
        latents = self.norm_latents(latents)

        h = self.heads

        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
        # Plain-torch equivalent of einops "b t n (h d) -> b h t n d". This avoids
        # depending on the optional ``einops_exts`` import, whose absence
        # previously caused a NameError here.
        q, k, v = (self._split_heads(t, h) for t in (q, k, v))
        q = q * self.scale

        sim = torch.einsum("... i d, ... j d -> ... i j", q, k)
        # Subtract the row max before softmax for numerical stability.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = torch.einsum("... i j, ... j d -> ... i d", attn, v)
        # "b h t n d -> b t n (h d)"
        out = out.permute(0, 2, 3, 1, 4).flatten(-2)
        return self.to_out(out)
|
|
|
|
|
class PerceiverResamplerModule(nn.Module):
    """Stack of Perceiver blocks that resamples media tokens into ``num_latents`` learned latents.

    Optional learned embeddings: ``frame_embs`` (one per frame, up to ``max_num_frames``)
    and ``media_time_embs`` (one per media item, up to ``max_num_media``). A block's
    feed-forward is replaced by ``nn.Identity`` when ``ff_mult <= 0``.
    """

    def __init__(
        self,
        *,
        dim,
        depth=6,
        dim_head=64,
        heads=8,
        num_latents=64,
        max_num_media=None,
        max_num_frames=None,
        ff_mult=4,
    ):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        if exists(max_num_frames):
            self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim))
        else:
            self.frame_embs = None

        if exists(max_num_media):
            self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim))
        else:
            self.media_time_embs = None

        # Each block is an [attention, feed-forward] pair, built attention-first.
        self.layers = nn.ModuleList(
            nn.ModuleList(
                [
                    PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                    FeedForward(dim=dim, mult=ff_mult) if ff_mult > 0 else nn.Identity(),
                ]
            )
            for _ in range(depth)
        )

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, T, F, v, D)
        Returns:
            shape (b, T, n, D) where n is self.num_latents
        """
        batch, num_media, num_frames, tokens_per_frame = x.shape[:4]

        if exists(self.frame_embs):
            # Broadcast each frame's embedding over all tokens of that frame.
            per_frame = repeat(
                self.frame_embs[:num_frames], "F d -> b T F v d", b=batch, T=num_media, v=tokens_per_frame
            )
            x = x + per_frame
        # Merge the frame and per-frame token axes into one token axis.
        x = rearrange(x, "b T F v d -> b T (F v) d")
        if exists(self.media_time_embs):
            x = x + self.media_time_embs[:num_media]

        latents = repeat(self.latents, "n d -> b T n d", b=batch, T=num_media)
        for cross_attn, feed_forward in self.layers:
            # Residual updates around attention and feed-forward.
            latents = cross_attn(x, latents) + latents
            latents = feed_forward(latents) + latents
        return self.norm(latents)
|
|
|
|
|
class PerceiverResampler(nn.Module):
    """Vision resampler wrapping ``PerceiverResamplerModule``.

    Reads ``mm_perceiver_depth``, ``mm_perceiver_latents``, ``mm_perceiver_ff_mult`` and
    (optional) ``mm_perceiver_pretrained`` from ``model_args``. The latent width equals
    ``vision_tower.hidden_size``.
    """

    def __init__(self, model_args, vision_tower):
        super().__init__()

        self.depth = model_args.mm_perceiver_depth
        self.num_latents = model_args.mm_perceiver_latents
        self.ff_mult = model_args.mm_perceiver_ff_mult
        self.pretrained = model_args.mm_perceiver_pretrained

        self.perceiver = PerceiverResamplerModule(dim=vision_tower.hidden_size, depth=self.depth, num_latents=self.num_latents, ff_mult=self.ff_mult)

        if self.pretrained is not None:
            # Deserialize on CPU (as Qformer does) so checkpoints saved from GPU load
            # on any host; load_state_dict copies into the existing parameters.
            self.load_state_dict(torch.load(self.pretrained, map_location="cpu"))

    def forward(self, image_features, *args, **kwargs):
        """Resample ``image_features`` (assumed ``(b, v, D)`` — TODO confirm) to ``(b, num_latents, D)``.

        Two singleton axes (T=1 media, F=1 frame) are inserted to match the module's
        ``(b, T, F, v, D)`` input; the T axis is squeezed out of the result.
        """
        return self.perceiver(image_features[:, None, None]).squeeze(1)

    @property
    def config(self):
        return {
            "mm_resampler_type": "perceiver",
            "mm_perceiver_depth": self.depth,
            "mm_perceiver_latents": self.num_latents,
            "mm_perceiver_ff_mult": self.ff_mult,
            "mm_perceiver_pretrained": self.pretrained,
        }
|
|
|
|
|
class MaskedDrop(nn.Module):
    """Training-time token dropping applied independently to each item of ``image_features``.

    Modes (``mm_mask_drop_mode``):
      * ``"fixed"``: keep ``int(num_tokens * mm_mask_drop_ratio)`` randomly chosen tokens.
      * ``"range"``: keep ``int(num_tokens * U(ratio_lower, ratio_upper))`` random tokens.
      * ``"cls_only"``: keep only the first token.
    The module is the identity in eval mode, and with probability
    ``mm_mask_drop_skip_percentage`` during training.
    """

    def __init__(self, model_args):
        super().__init__()

        self.mode = model_args.mm_mask_drop_mode
        self.skip_percentage = model_args.mm_mask_drop_skip_percentage
        self.ratio = model_args.mm_mask_drop_ratio
        self.ratio_upper = model_args.mm_mask_drop_ratio_upper
        self.ratio_lower = model_args.mm_mask_drop_ratio_lower

    def _drop_one(self, feat):
        """Apply the configured mode to a single item whose token axis is dim 0."""
        n_tokens = feat.shape[0]
        if self.mode == "fixed":
            keep = int(n_tokens * self.ratio)
            return self.random_masking(feat.unsqueeze(0), keep)[0][0]
        if self.mode == "range":
            keep = int(n_tokens * random.uniform(self.ratio_lower, self.ratio_upper))
            # NOTE(review): unlike "fixed", only one [0] is taken here, so each result
            # keeps a leading size-1 batch axis — confirm this is intended.
            return self.random_masking(feat.unsqueeze(0), keep)[0]
        if self.mode == "cls_only":
            return feat[0:1]
        raise ValueError(f"Unexpected masked drop mode: {self.mode}")

    def forward(self, image_features, *args, **kwargs):
        # ``random.random()`` is only drawn while training (short-circuit).
        if not self.training or self.skip_percentage > random.random():
            return image_features

        kept = [self._drop_one(feat) for feat in image_features]

        # "range" yields variable lengths, so it is never stacked. List inputs are only
        # stacked for "cls_only", where every item keeps exactly one token.
        should_stack = self.mode != "range" and (self.mode == "cls_only" or type(image_features) is not list)
        return torch.stack(kept, dim=0) if should_stack else kept

    @property
    def config(self):
        return {
            "mm_resampler_type": "masked_drop",
            "mm_mask_drop_mode": self.mode,
            "mm_mask_drop_skip_percentage": self.skip_percentage,
            "mm_mask_drop_ratio": self.ratio,
            "mm_mask_drop_ratio_upper": self.ratio_upper,
            "mm_mask_drop_ratio_lower": self.ratio_lower,
        }

    def random_masking(self, x, len_keep):
        """
        Keep ``len_keep`` randomly chosen tokens per sample of ``x`` ([N, L, D]).

        Returns:
            (x_masked [N, len_keep, D], mask [N, L] with 0 = kept / 1 = dropped in the
            original order, ids_restore [N, L] inverse permutation).
        """
        batch, length, dim = x.shape

        noise = torch.rand(batch, length, device=x.device)
        # Sorting uniform noise gives an independent random permutation per sample.
        ids_shuffle = noise.argsort(dim=1)
        ids_restore = ids_shuffle.argsort(dim=1)

        keep_idx = ids_shuffle[:, :len_keep]
        gather_idx = keep_idx.unsqueeze(-1).repeat(1, 1, dim)
        x_masked = torch.gather(x, dim=1, index=gather_idx)

        # Build the mask in shuffled order, then map it back to the original order.
        mask = torch.ones([batch, length], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore
|
|
|
class IdentityMap(torch.nn.Module):
    """Pass-through resampler: ``forward`` returns its first argument unchanged.

    ``build_vision_resampler`` returns it when no resampler type is configured.
    """

    def forward(self, x, *args, **kwargs):
        # Extra positional/keyword arguments are accepted for interface parity and ignored.
        return x

    @property
    def config(self):
        return {"mm_resampler_type": None}
|
|
|
|
|
def build_vision_resampler(model_args, delay_load=False, **kwargs):
    """Instantiate the resampler named by ``model_args.mm_resampler_type``.

    A missing or ``None`` type yields an ``IdentityMap``. Extra ``kwargs`` (e.g.
    ``vision_tower``) are forwarded to the spatial-pool, perceiver and Q-Former
    resamplers. ``delay_load`` is accepted but unused.

    Raises:
        ValueError: for an unrecognised resampler type.
    """
    resampler_type = getattr(model_args, "mm_resampler_type", None)

    if resampler_type is None:
        return IdentityMap()
    if resampler_type == "masked_drop":
        return MaskedDrop(model_args)
    if resampler_type == "spatial_pool":
        # NOTE(review): SpatialPool is not defined in this part of the module — confirm it exists.
        return SpatialPool(model_args, **kwargs)
    if resampler_type == "perceiver":
        return PerceiverResampler(model_args, **kwargs)
    if resampler_type == "qformer":
        return Qformer(model_args, **kwargs)

    raise ValueError(f"Unknown resampler type: {resampler_type}")
|
|
|
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig |
|
|
|
|
|
class CLIPVisionTower(nn.Module):
    r"""
    Image encoder backed by a Hugging Face ``CLIPVisionModel``.

    Attributes :
    ------------
    - is_loaded (bool): True once ``load_model`` has created ``vision_tower`` and ``image_processor``.
    - vision_tower_name (str): Model name or path passed to ``from_pretrained``.
    - select_layer (int): Index into ``hidden_states`` used by ``feature_select``.
    - select_feature (str): ``"patch"`` (drop position 0), ``"cls_patch"`` (keep all), or a
      ``slicefour_*`` / ``slice_m25811_f6_*`` variant that concatenates several layers along channels.

    Methods / properties :
    ------------
    - ``load_model(device_map=None)``: load the image processor and the frozen vision model (no-op if loaded).
    - ``feature_select(image_forward_outs)``: pick hidden states according to ``select_feature``.
    - ``forward(images)``: encode a batched tensor, or a list of per-image tensors (returns a list).
    - properties ``dummy_feature``, ``dtype``, ``device``, ``config``, ``hidden_size``,
      ``num_patches_per_side``, ``num_patches``, ``image_size``.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, "mm_vision_select_feature", "patch")

        # Decide whether to load weights now, and with which log message.
        if not delay_load:
            load_msg = f"Loading vision tower: {vision_tower}"
        elif getattr(args, "unfreeze_mm_vision_tower", False):
            load_msg = f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True."
        elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts:
            load_msg = f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`."
        else:
            load_msg = None

        if load_msg is None:
            # Deferred load: keep only the config so the shape properties still work.
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
        else:
            rank0_print(load_msg)
            self.load_model()

    def load_model(self, device_map=None):
        if self.is_loaded:
            rank0_print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name))
            return

        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
        # Freeze every parameter of the tower.
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        layers = image_forward_outs.hidden_states
        select_feature_type = self.select_feature

        if select_feature_type in ["slicefour_patch", "slicefour_cls_patch"]:
            # Every (len(layers) // 4)-th hidden state, starting at that stride offset by select_layer.
            stride = len(layers) // 4
            picked = [layers[i] for i in range(stride + self.select_layer, len(layers), stride)]
            image_features = torch.cat(picked, dim=-1)
            select_feature_type = select_feature_type.replace("slicefour_", "")
        elif select_feature_type in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]:
            image_features = torch.cat([layers[i] for i in (-2, -5, -8, -11, 6)], dim=-1)
            select_feature_type = select_feature_type.replace("slice_m25811_f6_", "")
        else:
            image_features = layers[self.select_layer]

        if select_feature_type == "patch":
            # Drop token position 0 (the CLS token).
            return image_features[:, 1:]
        if select_feature_type == "cls_patch":
            return image_features
        raise ValueError(f"Unexpected select feature: {select_feature_type}")

    def forward(self, images):
        if type(images) is list:
            per_image = []
            for img in images:
                batch = img.to(device=self.device, dtype=self.dtype).unsqueeze(0)
                outs = self.vision_tower(batch, output_hidden_states=True)
                # Cast back to the caller's dtype.
                per_image.append(self.feature_select(outs).to(img.dtype))
            return per_image

        outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
        return self.feature_select(outs).to(images.dtype)

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        return self.vision_tower.config if self.is_loaded else self.cfg_only

    @property
    def hidden_size(self):
        # Multi-layer modes concatenate along channels, widening the feature size.
        factor = 1
        if "slicefour" in self.select_feature:
            factor *= 4
        if "slice_m25811_f6" in self.select_feature:
            factor *= 5
        return self.config.hidden_size * factor

    @property
    def num_patches_per_side(self):
        return self.config.image_size // self.config.patch_size

    @property
    def num_patches(self):
        # One extra token when the CLS position is kept.
        extra = 1 if "cls_patch" in self.select_feature else 0
        return self.num_patches_per_side ** 2 + extra

    @property
    def image_size(self):
        return self.config.image_size
|
|
|
def build_vision_tower(vision_tower_cfg, **kwargs):
    """Create the vision tower named on ``vision_tower_cfg``.

    The name is read from ``mm_vision_tower``, falling back to ``vision_tower``.
    Existing filesystem paths, names starting with ``openai``/``laion``, and names
    containing ``ShareGPT4V`` are built as ``CLIPVisionTower``; extra ``kwargs``
    are forwarded to it.

    Raises:
        ValueError: if no tower name is configured or the name is not recognised.
    """
    vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None))
    if vision_tower is None:
        # Previously surfaced as an opaque TypeError from os.path.exists(None).
        raise ValueError("No vision tower configured: expected `mm_vision_tower` or `vision_tower` on the config.")

    # Any existing path (absolute or relative) counts as a local checkpoint.
    is_local_path = os.path.exists(vision_tower)
    if is_local_path or vision_tower.startswith(("openai", "laion")) or "ShareGPT4V" in vision_tower:
        return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)

    raise ValueError(f"Unknown vision tower: {vision_tower}")
|
|
|
class InstellaVLMetaModel:
    """Cooperative mixin that attaches the vision tower, resampler and projector to a model.

    ``__init__`` forwards ``config`` to the next class in the MRO. The methods rely on
    ``self.config`` and ``self.dtype`` being provided by that class.
    """

    def __init__(self, config):
        super(InstellaVLMetaModel, self).__init__(config)

        if hasattr(config, "mm_vision_tower"):
            delay_load = getattr(config, "delay_load", False)
            self.vision_tower = build_vision_tower(config, delay_load=delay_load)
            self.vision_resampler = build_vision_resampler(config, vision_tower=self.vision_tower)
            self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config)

            if "unpad" in getattr(config, "mm_patch_merge_type", ""):
                # Uninitialised here; expected to be filled from a checkpoint.
                self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))

    def get_vision_tower(self):
        """Return the vision tower, unwrapping the one-element list used under FSDP."""
        vision_tower = getattr(self, "vision_tower", None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower

    def initialize_vision_modules(self, model_args, fsdp=None):
        """Build (or reload) the vision modules and record their settings on ``self.config``.

        Args:
            model_args: provides ``vision_tower``, ``mm_vision_select_layer``,
                ``mm_vision_select_feature``, ``pretrain_mm_mlp_adapter``,
                ``mm_patch_merge_type`` and ``online_training``; optionally
                ``vision_tower_pretrained``, ``mm_projector_type`` and resampler settings.
            fsdp: when non-empty, the tower and resampler are stored wrapped in
                one-element lists.
        """
        vision_tower = model_args.vision_tower
        mm_vision_select_layer = model_args.mm_vision_select_layer
        mm_vision_select_feature = model_args.mm_vision_select_feature
        pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
        mm_patch_merge_type = model_args.mm_patch_merge_type

        self.config.mm_vision_tower = vision_tower
        self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "")

        use_fsdp = fsdp is not None and len(fsdp) > 0

        if self.get_vision_tower() is None:
            vision_tower = build_vision_tower(model_args)
            vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower)
            for k, v in vision_resampler.config.items():
                setattr(self.config, k, v)

            if use_fsdp:
                self.vision_tower = [vision_tower]
                self.vision_resampler = [vision_resampler]
            else:
                self.vision_tower = vision_tower
                self.vision_resampler = vision_resampler
        else:
            if use_fsdp:
                vision_resampler = self.vision_resampler[0]
                vision_tower = self.vision_tower[0]
            else:
                vision_resampler = self.vision_resampler
                vision_tower = self.vision_tower
            vision_tower.load_model()

        # Use the unwrapped local: under FSDP ``self.vision_resampler`` is a list,
        # so calling ``.parameters()`` on the attribute raised AttributeError.
        for p in vision_resampler.parameters():
            p.requires_grad = True

        self.config.use_mm_proj = True
        self.config.mm_projector_type = getattr(model_args, "mm_projector_type", "linear")
        self.config.mm_hidden_size = getattr(vision_resampler, "hidden_size", vision_tower.hidden_size)
        self.config.mm_vision_select_layer = mm_vision_select_layer
        self.config.mm_vision_select_feature = mm_vision_select_feature
        self.config.mm_patch_merge_type = mm_patch_merge_type
        self.config.online_training = model_args.online_training

        if getattr(self, "mm_projector", None) is None:
            self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config)

            if "unpad" in mm_patch_merge_type:
                # Scale the random init by 1 / sqrt(hidden_size).
                embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
                self.image_newline = nn.Parameter(torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std)
        else:
            # Existing projector: make sure it is trainable.
            for p in self.mm_projector.parameters():
                p.requires_grad = True

        if pretrain_mm_mlp_adapter is not None:
            # NOTE: torch.load unpickles — only point this at trusted checkpoints.
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu")

            def get_w(weights, keyword):
                # Keep entries under ``<keyword>.`` and strip that prefix. Matching on
                # the dotted prefix avoids an IndexError for keys that merely contain
                # the keyword.
                prefix = keyword + "."
                return {k.split(prefix)[1]: v for k, v in weights.items() if prefix in k}

            incompatible_keys = self.mm_projector.load_state_dict(get_w(mm_projector_weights, "mm_projector"))
            rank0_print(f"Loaded mm projector weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}")
            # Same FSDP list-wrapping concern as above: load into the unwrapped module.
            incompatible_keys = vision_resampler.load_state_dict(get_w(mm_projector_weights, "vision_resampler"), strict=False)
            rank0_print(f"Loaded vision resampler weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}")

            # NOTE(review): this deletes the adapter's whole parent directory whenever
            # 'tmp-' appears anywhere in the path (including the file name or an
            # ancestor directory) — confirm that only temporary download dirs reach here.
            if 'tmp-' in pretrain_mm_mlp_adapter:
                pretrain_mm_mlp_adapter_folder = os.path.dirname(pretrain_mm_mlp_adapter)
                shutil.rmtree(pretrain_mm_mlp_adapter_folder, ignore_errors=True)
|
|
|
|
|
|
|
def unpad_image(tensor, original_size):
    """
    Unpads a PyTorch tensor of a padded and resized image.

    The padding is assumed to be symmetric along the axis that was padded, i.e. the
    image was resized to fit ``tensor`` while preserving aspect ratio and then centred.

    Args:
        tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
        original_size (tuple): The original size of the image as ``(width, height)``
            (width first, like PIL's ``Image.size``).

    Returns:
        torch.Tensor: The unpadded image tensor.
    """
    # NOTE: width comes first in ``original_size``.
    original_width, original_height = original_size
    current_height, current_width = tensor.shape[1:]

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        # Original is relatively wider: it was fitted to the width, so padding is vertical.
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        unpadded_tensor = tensor[:, padding : current_height - padding, :]
    else:
        # Original is relatively taller (or equal): padding is horizontal.
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        unpadded_tensor = tensor[:, :, padding : current_width - padding]

    return unpadded_tensor
|
|
|
|
|
class InstellaVLMetaForCausalLM(ABC): |
|
|
|
@abstractmethod |
|
def get_model(self): |
|
pass |
|
|
|
def get_vision_tower(self): |
|
return self.get_model().get_vision_tower() |
|
|
|
def get_2dPool(self, image_feature): |
|
height = width = self.get_vision_tower().num_patches_per_side |
|
num_frames, num_tokens, num_dim = image_feature.shape |
|
image_feature = image_feature.view(num_frames, height, width, -1) |
|
image_feature = image_feature.permute(0, 3, 1, 2).contiguous() |
|
|
|
if self.config.mm_spatial_pool_mode == "average": |
|
image_feature = nn.functional.avg_pool2d(image_feature, self.config.mm_spatial_pool_stride) |
|
elif self.config.mm_spatial_pool_mode == "max": |
|
image_feature = nn.functional.max_pool2d(image_feature, self.config.mm_spatial_pool_stride) |
|
elif self.config.mm_spatial_pool_mode == "bilinear": |
|
height, weight = image_feature.shape[2:] |
|
scaled_shape = [math.ceil(height / 2), math.ceil(weight / 2)] |
|
image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode='bilinear') |
|
|
|
else: |
|
raise ValueError(f"Unexpected mm_spatial_pool_mode: {self.config.mm_spatial_pool_mode}") |
|
image_feature = image_feature.permute(0, 2, 3, 1) |
|
image_feature = image_feature.view(num_frames, -1, num_dim) |
|
return image_feature |
|
|
|
def encode_images(self, images): |
|
image_features = self.get_model().get_vision_tower()(images) |
|
|
|
image_features = self.get_model().mm_projector(image_features) |
|
return image_features |
|
|
|
def encode_multimodals(self, videos_or_images, video_idx_in_batch, split_sizes=None): |
|
videos_or_images_features = self.get_model().get_vision_tower()(videos_or_images) |
|
per_videos_or_images_features = torch.split(videos_or_images_features, split_sizes, dim=0) |
|
all_videos_or_images_features = [] |
|
|
|
for idx, feat in enumerate(per_videos_or_images_features): |
|
feat = self.get_model().mm_projector(feat) |
|
if idx in video_idx_in_batch: |
|
feat = self.get_2dPool(feat) |
|
all_videos_or_images_features.append(feat) |
|
return all_videos_or_images_features |
|
|
|
def add_token_per_grid(self, image_feature):
    """Append ``self.model.image_newline`` after every row of each frame's patch grid.

    Assumes ``image_feature`` is ``(num_frames, side * side, dim)`` with a
    square grid; ``side`` is recovered as ``int(sqrt(shape[1]))``. The
    ``[:, None, None]`` broadcast below implies ``image_newline`` is a
    length-``dim`` vector.

    Returns:
        Tensor of shape ``(num_frames * side * (side + 1), dim)``, row-major
        over frames and grid rows, with a newline token closing each row.
    """
    side = int(math.sqrt(image_feature.shape[1]))
    frames = image_feature.shape[0]
    # (F, 1, side, side, D) -> (D, F, side, 1, side) -> (D, F * side, side)
    grid = image_feature.view(frames, 1, side, side, -1).permute(4, 0, 2, 1, 3).contiguous()
    grid = grid.flatten(1, 2).flatten(2, 3)
    newline = self.model.image_newline[:, None, None].expand(*grid.shape[:-1], 1).to(grid.device)
    grid = torch.cat((grid, newline), dim=-1)
    # (D, rows, side + 1) -> (rows * (side + 1), D)
    return grid.flatten(1, 2).transpose(0, 1)
|
|
|
def add_token_per_frame(self, image_feature):
    """Append ``self.model.image_newline`` as one extra token at the end of every frame.

    Assumes ``image_feature`` is ``(num_frames, num_tokens, dim)``; returns a
    contiguous ``(num_frames, num_tokens + 1, dim)`` tensor.
    """
    channels_first = image_feature.permute(2, 0, 1).contiguous()  # (D, F, T)
    newline = self.model.image_newline[:, None, None].expand(*channels_first.shape[:-1], 1)
    channels_first = torch.cat((channels_first, newline.to(channels_first.device)), dim=-1)
    return channels_first.permute(1, 2, 0).contiguous()
|
|
|
def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities=None, image_sizes=None):
    """Splice visual features into the token-embedding sequence for the LM.

    Each ``IMAGE_TOKEN_INDEX`` placeholder in ``input_ids`` is replaced by the
    next entry of the encoded image/video features. The matching label
    positions are filled with ``IGNORE_INDEX``. Every sample is then
    truncated to ``config.tokenizer_model_max_length`` (when set). Samples
    are padded to a common length on the side named by
    ``config.tokenizer_padding_side`` (default ``"right"``).

    Args:
        input_ids: token ids; indexed via ``input_ids.shape[1]``, so presumably
            ``(batch, seq_len)``.
        position_ids: optional; only its dtype/device (or ``None``-ness) is used.
        attention_mask: optional; masked-out positions are dropped before splicing.
        past_key_values: returned untouched.
        labels: optional LM labels aligned with ``input_ids``.
        images: a tensor or a list of tensors (3-D items get a leading dim);
            a list or a 5-D tensor takes the per-sample merge path.
        modalities: per-sample modality names (``"video"`` marks samples that
            are spatially pooled). A single string is wrapped in a list;
            ``None`` means ``["image"]``.
        image_sizes: per-sample original image sizes, used by the AnyRes /
            unpad merge paths.

    Returns:
        ``(input_ids, position_ids, attention_mask, past_key_values, None, labels)``
        unchanged when there is no vision tower, no images, or ``input_ids``
        has a single column. Otherwise returns
        ``(None, position_ids, attention_mask, past_key_values, inputs_embeds, labels)``.
        In that tuple, labels / attention_mask / position_ids are ``None`` if
        they were ``None`` on input. The exception is that
        ``config.use_pos_skipping`` during training always produces shifted
        position ids.

    Raises:
        ValueError: unknown ``mm_patch_merge_type`` / ``mm_newline_position``,
            or a vision tower without ``image_size`` when AnyRes is requested.
        NotImplementedError: when the config enables both
            ``tune_mm_mlp_adapter`` and ``mm_use_im_start_end``.
    """
    vision_tower = self.get_vision_tower()
    if vision_tower is None or images is None or input_ids.shape[1] == 1:
        return input_ids, position_ids, attention_mask, past_key_values, None, labels

    if modalities is None:
        modalities = ["image"]
    elif isinstance(modalities, str):
        modalities = [modalities]

    if isinstance(images, list) or images.ndim == 5:
        image_features = self._encode_and_merge_image_list(images, modalities, image_sizes)
    else:
        image_features = self.encode_images(images)

    if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False):
        raise NotImplementedError

    # Remember the caller's inputs so that ``None`` can be handed back as ``None``.
    _labels = labels
    _position_ids = position_ids
    _attention_mask = attention_mask
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
    else:
        attention_mask = attention_mask.bool()
    if position_ids is None:
        position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
    if labels is None:
        labels = torch.full_like(input_ids, IGNORE_INDEX)

    # Drop masked-out positions; each sample becomes a 1-D tensor of its real tokens.
    input_ids = [cur_input_ids[cur_mask] for cur_input_ids, cur_mask in zip(input_ids, attention_mask)]
    labels = [cur_labels[cur_mask] for cur_labels, cur_mask in zip(labels, attention_mask)]

    embed_tokens = self.get_model().embed_tokens
    new_input_embeds = []
    new_labels = []
    cur_image_idx = 0

    for batch_idx, cur_input_ids in enumerate(input_ids):
        num_images = int((cur_input_ids == IMAGE_TOKEN_INDEX).sum())

        if num_images == 0:
            cur_input_embeds = embed_tokens(cur_input_ids)
            # Text-only sample: append an empty slice of an image feature
            # (presumably so the vision path stays in the autograd graph).
            # If no feature is available at all, the slice is skipped.
            try:
                filler = image_features[cur_image_idx]
            except IndexError:
                try:
                    filler = image_features[cur_image_idx - 1]
                except IndexError:
                    filler = None
            if filler is not None:
                cur_input_embeds = torch.cat([cur_input_embeds, filler[0:0]], dim=0)
            new_input_embeds.append(cur_input_embeds)
            new_labels.append(labels[batch_idx])
            cur_image_idx += 1
            continue

        cur_labels = labels[batch_idx]
        # Text chunk i spans the open interval (bounds[i], bounds[i + 1]).
        bounds = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
        cur_input_ids_noim = []
        cur_labels_noim = []
        for start, stop in zip(bounds[:-1], bounds[1:]):
            cur_input_ids_noim.append(cur_input_ids[start + 1:stop])
            cur_labels_noim.append(cur_labels[start + 1:stop])
        # Embed all text chunks in one call, then split back per chunk.
        chunk_sizes = [chunk.shape[0] for chunk in cur_labels_noim]
        cur_input_embeds_no_im = torch.split(embed_tokens(torch.cat(cur_input_ids_noim)), chunk_sizes, dim=0)

        cur_new_input_embeds = []
        cur_new_labels = []
        for i in range(num_images + 1):
            cur_new_input_embeds.append(cur_input_embeds_no_im[i])
            cur_new_labels.append(cur_labels_noim[i])
            if i < num_images:
                # More placeholders than features: fall back to the previous feature.
                try:
                    cur_image_features = image_features[cur_image_idx]
                except IndexError:
                    cur_image_features = image_features[cur_image_idx - 1]
                cur_image_idx += 1
                cur_new_input_embeds.append(cur_image_features)
                cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))

        cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
        new_input_embeds.append(torch.cat(cur_new_input_embeds))
        new_labels.append(torch.cat(cur_new_labels))

    # Image features can push a sample past the tokenizer limit. Truncate every
    # sample; a ``None`` limit slices nothing off.
    tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
    new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
    new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

    max_len = max(x.shape[0] for x in new_input_embeds)
    batch_size = len(new_input_embeds)
    pad_left = getattr(self.config, "tokenizer_padding_side", "right") == "left"

    new_input_embeds_padded = []
    new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
    attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
    position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)

    for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
        cur_len = cur_new_embed.shape[0]
        padding = torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
        positions = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
        if pad_left:
            new_input_embeds_padded.append(torch.cat((padding, cur_new_embed), dim=0))
            if cur_len > 0:
                new_labels_padded[i, -cur_len:] = cur_new_labels
                attention_mask[i, -cur_len:] = True
                position_ids[i, -cur_len:] = positions
        else:
            new_input_embeds_padded.append(torch.cat((cur_new_embed, padding), dim=0))
            if cur_len > 0:
                new_labels_padded[i, :cur_len] = cur_new_labels
                attention_mask[i, :cur_len] = True
                position_ids[i, :cur_len] = positions

    new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

    new_labels = None if _labels is None else new_labels_padded
    attention_mask = None if _attention_mask is None else attention_mask.to(dtype=_attention_mask.dtype)
    if _position_ids is None:
        position_ids = None

    if getattr(self.config, "use_pos_skipping", False) and self.training:
        # Shift the positions before/after a random split point by random offsets
        # drawn from [0, pos_skipping_range].
        position_ids = torch.arange(new_input_embeds.size(1), device=new_input_embeds.device).unsqueeze(0).to(new_input_embeds.device)
        split_position = random.randint(0, new_input_embeds.size(1))
        left_add = random.randint(0, self.config.pos_skipping_range)
        right_add = random.randint(left_add, self.config.pos_skipping_range)
        position_ids[:, :split_position] += left_add
        position_ids[:, split_position:] += right_add

    return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels

def _encode_and_merge_image_list(self, images, modalities, image_sizes):
    """Encode a list / 5-D batch of images or videos into one 2-D feature tensor per sample.

    Helper for ``prepare_inputs_labels_for_multimodal``. Samples whose
    modality is ``"video"`` are pooled with ``get_2dPool``. Each sample then
    contributes exactly one entry to the returned list.
    """
    if isinstance(images, list):
        images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]

    video_idx_in_batch = [idx for idx, modality in enumerate(modalities) if modality == "video"]

    images_list = [image if image.ndim == 4 else image.unsqueeze(0) for image in images]
    concat_images = torch.cat(images_list, dim=0)
    split_sizes = [image.shape[0] for image in images_list]
    encoded_image_features = torch.split(self.encode_images(concat_images), split_sizes)
    image_features = [
        self.get_2dPool(feat) if idx in video_idx_in_batch else feat
        for idx, feat in enumerate(encoded_image_features)
    ]

    mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
    image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")

    if mm_patch_merge_type == "flat":
        return [x.flatten(0, 1) for x in image_features]
    if not mm_patch_merge_type.startswith("spatial"):
        raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")

    merged = []
    for image_idx, image_feature in enumerate(image_features):
        if image_idx in video_idx_in_batch:
            image_feature = self._merge_video_feature(image_feature, mm_patch_merge_type)
        elif image_feature.shape[0] > 1:
            image_feature = self._merge_multi_patch_feature(image_feature, image_idx, image_sizes, mm_patch_merge_type, image_aspect_ratio)
        else:
            image_feature = image_feature[0]
            if "unpad" in mm_patch_merge_type:
                image_feature = torch.cat((image_feature, self.model.image_newline[None].to(image_feature.device)), dim=0)
        # Exactly one entry per sample keeps features aligned with placeholders.
        merged.append(image_feature)
    return merged

def _merge_video_feature(self, image_feature, mm_patch_merge_type):
    """Flatten one video's pooled per-frame features to ``(tokens, dim)``, adding newline tokens per ``config.mm_newline_position``."""
    newline_position = self.config.mm_newline_position
    if newline_position == "grid":
        return self.add_token_per_grid(image_feature)
    if newline_position == "frame":
        return self.add_token_per_frame(image_feature).flatten(0, 1)
    if newline_position == "one_token":
        image_feature = image_feature.flatten(0, 1)
        if 'unpad' in mm_patch_merge_type:
            image_feature = torch.cat((image_feature, self.model.image_newline[None].to(image_feature.device)), dim=0)
        return image_feature
    if newline_position == "no_token":
        return image_feature.flatten(0, 1)
    raise ValueError(f"Unexpected mm_newline_position: {self.config.mm_newline_position}")

def _merge_multi_patch_feature(self, image_feature, image_idx, image_sizes, mm_patch_merge_type, image_aspect_ratio):
    """Merge one multi-view sample (base view + crops) into a ``(tokens, dim)`` sequence.

    ``image_feature[0]`` is the base view and ``image_feature[1:]`` are
    crops. The crops are laid out on the AnyRes grid from
    ``get_anyres_image_grid_shape`` (or a fixed 2x2 grid) and combined per
    ``mm_patch_merge_type``. The base view is prepended unless the merge
    type contains "nobase".
    """
    base_image_feature = image_feature[0]
    image_feature = image_feature[1:]
    height = width = self.get_vision_tower().num_patches_per_side
    assert height * width == base_image_feature.shape[0]

    # "anyres_max_<N>" bounds the unpadded feature map to roughly N * unit**2
    # positions in the unpad branch below.
    max_num_patches = None
    if "anyres_max" in image_aspect_ratio:
        matched_anyres_max_num_patches = re.match(r"anyres_max_(\d+)", image_aspect_ratio)
        if matched_anyres_max_num_patches:
            max_num_patches = int(matched_anyres_max_num_patches.group(1))

    if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
        vision_tower = self.get_vision_tower()
        if not hasattr(vision_tower, "image_size"):
            raise ValueError("vision_tower_image_size is not found in the vision tower.")
        try:
            num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, vision_tower.image_size)
        except Exception as e:
            # Deliberate best-effort fallback: log and assume a 2x2 crop grid.
            rank0_print(f"Error: {e}")
            num_patch_width, num_patch_height = 2, 2
        image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
    else:
        image_feature = image_feature.view(2, 2, height, width, -1)

    if "maxpool2x2" in mm_patch_merge_type:
        image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
        image_feature = nn.functional.max_pool2d(image_feature, 2)
        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
    elif "unpad" in mm_patch_merge_type and max_num_patches is not None:
        unit = image_feature.shape[2]
        image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
        image_feature = unpad_image(image_feature, image_sizes[image_idx])
        _, h, w = image_feature.shape
        times = math.sqrt(h * w / (max_num_patches * unit**2))
        if times > 1.1:
            image_feature = nn.functional.interpolate(image_feature[None], [int(h // times), int(w // times)], mode="bilinear")[0]
        image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
    elif "unpad" in mm_patch_merge_type:
        image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
        image_feature = unpad_image(image_feature, image_sizes[image_idx])
        image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
    else:
        image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
        image_feature = image_feature.flatten(0, 3)

    if "nobase" not in mm_patch_merge_type:
        image_feature = torch.cat((base_image_feature, image_feature), dim=0)
    return image_feature
|
|
|
def initialize_vision_tokenizer(self, model_args, tokenizer):
    """Add image special tokens to ``tokenizer`` and set up their embeddings.

    Behaviour is driven by flags on ``model_args``:

    * ``mm_use_im_patch_token``: add ``DEFAULT_IMAGE_PATCH_TOKEN`` and resize
      the token embeddings to ``len(tokenizer)``.
    * ``mm_use_im_start_end``: add ``DEFAULT_IM_START_TOKEN`` /
      ``DEFAULT_IM_END_TOKEN`` and resize. Each newly added input/output
      embedding row is set to the mean of the other rows. Then
      ``tune_mm_mlp_adapter`` makes input embeddings trainable and freezes
      output embeddings. If ``pretrain_mm_mlp_adapter`` (a checkpoint path) is
      given, the new input rows are overwritten from that checkpoint's
      ``"model.embed_tokens.weight"``.
    * otherwise, if both ``mm_use_im_patch_token`` and ``tune_mm_mlp_adapter``
      are set, input and output embeddings are frozen.

    Raises:
        ValueError: if a pretrained adapter is requested but the tokenizer did
            not gain exactly two new tokens, or if the checkpoint's embedding
            shape matches neither the current matrix nor the new-token count.
    """
    if model_args.mm_use_im_patch_token:
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        self.resize_token_embeddings(len(tokenizer))

    if model_args.mm_use_im_start_end:
        num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
        self.resize_token_embeddings(len(tokenizer))

        if num_new_tokens > 0:
            input_embeddings = self.get_input_embeddings().weight.data
            output_embeddings = self.get_output_embeddings().weight.data

            # New rows := mean of every row except the last num_new_tokens.
            input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
            output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

            input_embeddings[-num_new_tokens:] = input_embeddings_avg
            output_embeddings[-num_new_tokens:] = output_embeddings_avg

        if model_args.tune_mm_mlp_adapter:
            for p in self.get_input_embeddings().parameters():
                p.requires_grad = True
            for p in self.get_output_embeddings().parameters():
                p.requires_grad = False

        if model_args.pretrain_mm_mlp_adapter:
            # Explicit check rather than ``assert``: asserts vanish under
            # ``python -O``, which would otherwise surface later as an
            # unbound ``input_embeddings`` when no tokens were added.
            if num_new_tokens != 2:
                raise ValueError(f"Expected exactly 2 new image start/end tokens when loading pretrain_mm_mlp_adapter, got {num_new_tokens}.")
            # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
            mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu")
            embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
            input_embeddings = self.get_input_embeddings().weight.data
            if input_embeddings.shape == embed_tokens_weight.shape:
                input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
            elif embed_tokens_weight.shape[0] == num_new_tokens:
                input_embeddings[-num_new_tokens:] = embed_tokens_weight
            else:
                raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
    elif model_args.mm_use_im_patch_token:
        if model_args.tune_mm_mlp_adapter:
            for p in self.get_input_embeddings().parameters():
                p.requires_grad = False
            for p in self.get_output_embeddings().parameters():
                p.requires_grad = False
|
|
|
class InstellaVLConfig(OlmoConfig):
    """OLMo configuration published under the ``"instellavl"`` model type.

    It declares no fields of its own; it only overrides ``model_type`` so this
    config can be registered with the transformers Auto* factories.
    """

    model_type = "instellavl"
|
|
|
|
|
def disable_torch_init():
    r"""
    Make ``reset_parameters`` a no-op on ``torch.nn.Linear`` and ``torch.nn.LayerNorm``.

    The intent is to skip the redundant default weight initialisation and
    speed up model construction.

    NOTE(review): the classes are patched process-wide and the patch is never
    undone, so every Linear/LayerNorm created afterwards skips it too.
    """
    import torch

    def _skip_init(self):
        return None

    for layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        setattr(layer_cls, "reset_parameters", _skip_init)
|
|
|
|
|
class InstellaVLModel(InstellaVLMetaModel, OlmoModel):
    """OLMo decoder stack combined with the InstellaVL multimodal mixin.

    Construction is delegated to the cooperative ``__init__`` chain of the
    base classes, with ``InstellaVLMetaModel`` first in the MRO.
    """

    config_class = InstellaVLConfig

    def __init__(self, config: OlmoConfig):
        super().__init__(config)
|
|
|
|
|
class InstellaVLForCausalLM(OlmoForCausalLM, InstellaVLMetaForCausalLM):
    r"""
    InstellaVLForCausalLM extends OlmoForCausalLM and InstellaVLMetaForCausalLM to provide
    a language model with multimodal capabilities, specifically for handling images along with text.

    1. Attributes:
        - config_class (type): The configuration class to use for this model.
        - model (InstellaVLModel): The underlying model.
        - lm_head (nn.Linear): The linear layer for language modeling head.

    2. Methods:
        1. `__init__(config: InstellaVLConfig)`:
            Initializes the InstellaVLForCausalLM model with the given configuration.
        2. `get_model() -> InstellaVLModel`:
            Returns the underlying model.
        3. `forward() -> Union[Tuple, CausalLMOutputWithPast]`:
            Performs a forward pass through the model.
        4. `generate() -> Union[GenerateOutput, torch.LongTensor]`:
            Generates text based on the input.
        5. `prepare_inputs_for_generation(input_ids: torch.LongTensor,) -> dict`:
            Prepares inputs for text generation.
    """

    config_class = InstellaVLConfig

    def __init__(self, config: OlmoConfig):
        r"""
        Initializes the InstellaVLForCausalLM model.

        Args:
            - config (OlmoConfig): Configuration object for the model. Its
              ``model_type`` is overwritten with ``"instellavl"``.

        Attributes:
            - model (InstellaVLModel): The main model instance.
            - lm_head (torch.nn.Linear): Linear layer that maps hidden states to vocabulary size.
        """
        # Skip OlmoForCausalLM.__init__ and run the next initializer in the MRO
        # (presumably to avoid building a plain OlmoModel); the multimodal
        # backbone is built below instead.
        super(OlmoForCausalLM, self).__init__(config)
        # NOTE(review): disable_torch_init() replaces reset_parameters on
        # torch.nn.Linear / LayerNorm process-wide and never restores them.
        disable_torch_init()
        config.model_type = "instellavl"
        self.model = InstellaVLModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_model(self):
        """Return the wrapped :class:`InstellaVLModel`."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        modalities: Optional[List[str]] = None,
        cache_position=None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            - input_ids (torch.LongTensor, optional): Input token IDs.
            - attention_mask (torch.Tensor, optional): Attention mask.
            - position_ids (torch.LongTensor, optional): Position IDs.
            - past_key_values (List[torch.FloatTensor], optional): Past key values for caching.
            - inputs_embeds (torch.FloatTensor, optional): Input embeddings. When given,
              the multimodal splicing step is skipped.
            - labels (torch.LongTensor, optional): Labels for language modeling.
            - use_cache (bool, optional): Whether to use cache.
            - output_attentions (bool, optional): Whether to output attentions.
            - output_hidden_states (bool, optional): Whether to output hidden states.
            - images (torch.FloatTensor, optional): Input images.
            - image_sizes (List[List[int]], optional): Sizes of input images.
            - return_dict (bool, optional): Whether to return a dictionary.
            - modalities (List[str], optional): Per-sample modalities; ``None`` means ``["image"]``.
            - cache_position (optional): Accepted for API compatibility; not forwarded.

        Returns:
            Union[Tuple, CausalLMOutputWithPast]: The output of the forward pass.
        """
        if modalities is None:
            modalities = ["image"]
        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                modalities,
                image_sizes,
            )

        # NOTE(review): cache_position is not passed on to the parent forward.
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        modalities: Optional[List[str]] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        r"""
        Args:
            - inputs (torch.Tensor, optional): Input token IDs.
            - images (torch.Tensor, optional): Input images. When given, their features are
              spliced into the prompt embeddings before generation.
            - image_sizes (torch.Tensor, optional): Sizes of input images.
            - modalities (List[str], optional): Per-sample modalities (e.g. ``"video"``);
              ``None`` means ``["image"]``. Forwarded to the multimodal splicing step.
            - **kwargs: Additional arguments forwarded to the parent ``generate``.

        Returns:
            Union[GenerateOutput, torch.LongTensor]: The generated output.

        Raises:
            NotImplementedError: if ``inputs_embeds`` is passed in ``kwargs``.
        """
        if modalities is None:
            modalities = ["image"]
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            (
                inputs,
                position_ids,
                attention_mask,
                _,
                inputs_embeds,
                _,
            ) = self.prepare_inputs_labels_for_multimodal(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                modalities,
                image_sizes=image_sizes,
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        r"""
        Args:
            - input_ids (torch.LongTensor): Input token IDs.
            - past_key_values (List[torch.FloatTensor], optional): Past key values for caching.
            - inputs_embeds (torch.FloatTensor, optional): Input embeddings.
            - **kwargs: Additional arguments.

        Returns:
            dict: The parent's prepared inputs, plus ``images`` / ``image_sizes`` when they
            were supplied in ``kwargs``.
        """
        # Pull the multimodal extras out so the parent does not see them, then re-attach.
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            inputs['images'] = images
        if image_sizes is not None:
            inputs['image_sizes'] = image_sizes
        return inputs
|
|
|
# Expose the InstellaVL classes through the transformers Auto* factories,
# keyed by the config's own model_type ("instellavl") so the two never drift.
AutoConfig.register(InstellaVLConfig.model_type, InstellaVLConfig)
AutoModelForCausalLM.register(InstellaVLConfig, InstellaVLForCausalLM)
|
|