# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Aero model."""
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from transformers import AutoConfig, AutoModel
from transformers.activations import ACT2FN
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import BaseModelOutput, ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.models.auto import AutoModelForCausalLM
from transformers.utils import logging
from transformers.models.qwen2_audio.modeling_qwen2_audio import Qwen2AudioFlashAttention2
from .configuration_aero import AeroConfig
try:
from flash_attn import flash_attn_func
except ImportError:
    print("flash_attn is not installed. Install flash-attn to use flash attention for the audio tower.")
logger = logging.get_logger(__name__)
@dataclass
# Copied from transformers.models.llava_next_video.modeling_llava_next_video.LlavaNextVideoCausalLMOutputWithPast with LlavaNextVideo->Aero
class AeroCausalLMOutputWithPast(ModelOutput):
"""
Base class for Aero causal language model (or autoregressive) outputs.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
        audio_hidden_states (`torch.FloatTensor`, *optional*):
            Audio hidden states produced by the audio encoder, after projection of its last hidden state into the
            language model's embedding space.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
audio_hidden_states: Optional[torch.FloatTensor] = None
# The original flash-attention implementation for the Qwen2Audio encoder in transformers is buggy,
# so we patch Qwen2AudioFlashAttention2.forward with the function below.
def qwen2_audio_flash_attn_forward(
    self,
    hidden_states: torch.Tensor,
    key_value_states=None,
    past_key_value=None,
    attention_mask=None,
    layer_head_mask=None,
    output_attentions: bool = False,
    cache_position=None,
):
# Qwen2AudioFlashAttention2 attention does not support output_attentions
if output_attentions:
raise ValueError("Qwen2AudioFlashAttention2 attention does not support output_attentions")
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = torch.reshape(self.q_proj(hidden_states), (bsz, tgt_len, self.num_heads, self.head_dim))
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]
# We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view.
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
    # flash_attn_func does not accept an attention mask; the sliced mask is kept only to mirror the
    # original transformers implementation and is unused below.
    causal_mask = attention_mask
    if attention_mask is not None:  # no matter the length, we just slice it
        causal_mask = attention_mask[:, : key_states.shape[-2]]
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
    dropout = self.dropout if self.training else 0.0
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=None, causal=self.is_causal
)
attn_output = attn_output.reshape(bsz, tgt_len, -1)
attn_output = self.out_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, None
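# A minimal layout sanity-check sketch (illustrative only; `_check_flash_attn_layout` is an assumption,
# not part of the original model). flash_attn_func expects q/k/v shaped
# [batch, seq_len, num_heads, head_dim], whereas torch's scaled_dot_product_attention expects
# [batch, num_heads, seq_len, head_dim]; the patched forward above therefore keeps the flash-attn layout.
def _check_flash_attn_layout(q, k, v, atol=1e-3):
    # q, k, v: [batch, seq_len, num_heads, head_dim], fp16/bf16, on a CUDA device
    out_flash = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=False)
    out_sdpa = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    ).transpose(1, 2)
    return torch.allclose(out_flash, out_sdpa, atol=atol)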
class AeroAudioMultiModalProjector(nn.Module):
def __init__(self, config: AeroConfig):
super().__init__()
self.linear = nn.Linear(
config.audio_config.d_model, config.text_config.hidden_size, bias=True
)
def forward(self, audio_features):
hidden_states = self.linear(audio_features)
return hidden_states
class AeroPreTrainedModel(PreTrainedModel):
config_class = AeroConfig
base_model_prefix = "language_model"
supports_gradient_checkpointing = True
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_cache_class = True
    _supports_static_cache = False  # Qwen2 does not support static cache
_supports_quantized_cache = True
_supports_sdpa = True
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextPreTrainedModel._init_weights
def _init_weights(self, module):
# important: this ported version of LlavaNext isn't meant for training from scratch - only
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
# https://github.com/haotian-liu/LLaVA/tree/main/llava_next should serve for that purpose
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)
if hasattr(module, "class_embedding"):
module.class_embedding.data.normal_(mean=0.0, std=std)
        if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
class AeroForConditionalGeneration(AeroPreTrainedModel, GenerationMixin):
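    """
    Aero conditional generation model: a Qwen2Audio-style audio encoder whose projected features are
    scattered into the language model's input embeddings at the positions of the audio placeholder tokens.

    Example (a minimal usage sketch; the checkpoint path, processor API, and audio placeholder handling
    below are assumptions, not guaranteed by this file):

    ```python
    >>> from transformers import AutoProcessor

    >>> model_id = "path/to/Aero-1-Audio"  # hypothetical local or Hub path with remote code
    >>> processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    >>> model = AeroForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto")

    >>> inputs = processor(text=prompt_with_audio_token, audios=[waveform], sampling_rate=16000, return_tensors="pt")
    >>> output_ids = model.generate(**inputs, max_new_tokens=64)
    >>> print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])
    ```
    """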
def __init__(self, config: AeroConfig):
super().__init__(config)
if config._attn_implementation == "flash_attention_2":
Qwen2AudioFlashAttention2.forward = qwen2_audio_flash_attn_forward
self.audio_tower_type = config.audio_config.model_type
self.audio_tower = AutoModel.from_config(config.audio_config)
self.audio_modal_projector = AeroAudioMultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
self.post_init()
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_input_embeddings
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_input_embeddings
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_output_embeddings
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_output_embeddings
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_decoder
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_decoder
def get_decoder(self):
return self.language_model.get_decoder()
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.tie_weights
def tie_weights(self):
return self.language_model.tie_weights()
def prepare_inputs_for_qwen_audio_encoder(
self,
audio_values: torch.Tensor,
audio_attention_mask: torch.Tensor,
audio_feat_lengths: torch.FloatTensor,
audio_output_lengths: torch.FloatTensor,
):
batch_size, _, max_mel_seq_len = audio_values.shape
max_seq_len = (max_mel_seq_len - 2) // 2 + 1
# Create a sequence tensor of shape (batch_size, max_seq_len)
seq_range = (
torch.arange(
0,
max_seq_len,
dtype=audio_feat_lengths.dtype,
device=audio_feat_lengths.device,
)
.unsqueeze(0)
.expand(batch_size, max_seq_len)
)
lengths_expand = audio_feat_lengths.unsqueeze(1).expand(batch_size, max_seq_len)
# Create mask
padding_mask = seq_range >= lengths_expand
audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
batch_size, 1, max_seq_len, max_seq_len
)
audio_attention_mask = audio_attention_mask_.to(
dtype=self.audio_tower.conv1.weight.dtype,
device=self.audio_tower.conv1.weight.device,
)
audio_attention_mask[audio_attention_mask_] = float("-inf")
inputs = {
"input_features": audio_values,
"attention_mask": audio_attention_mask,
}
return inputs
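    # Worked example (illustrative arithmetic only): a Whisper-style feature extractor turns 30 s of
    # 16 kHz audio into max_mel_seq_len = 3000 mel frames, so the encoder length used for the attention
    # mask above is max_seq_len = (3000 - 2) // 2 + 1 = 1500 frames after the strided convolutions.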
def prepare_scattered_audio_values(
self,
audio_features,
audio_output_lengths,
):
        # Audio features are of shape (batch_size, max_seq_len, hidden_size).
        # If we masked-scatter them directly, the padded positions would be scattered too and the
        # embeddings would land in the wrong order, so we strip the padding first.
unpadded_audio_features = [
audio_feat[:audio_output_length]
for audio_feat, audio_output_length in zip(
audio_features, audio_output_lengths
)
]
# Concat the audio features
# Should exactly have audio_mask.sum() values
unpadded_audio_features = torch.concatenate(unpadded_audio_features, dim=0)
return unpadded_audio_features
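    # Worked example (illustrative only): with audio_output_lengths = [3, 2] and padded audio_features of
    # shape (2, 4, hidden_size), the rows kept are audio_features[0, :3] and audio_features[1, :2],
    # concatenated into a (5, hidden_size) tensor whose row count matches the 5 audio placeholder tokens
    # in input_ids, so masked_scatter fills them in the correct order.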
def forward(
self,
input_ids: torch.LongTensor = None,
audio_values: torch.FloatTensor = None,
audio_attention_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: int = 0,
) -> Union[Tuple, AeroCausalLMOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You must specify exactly one of input_ids or inputs_embeds"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
# Embed audio features
if audio_values is not None:
(
audio_feat_lengths,
audio_output_lengths,
) = self.audio_tower._get_feat_extract_output_lengths(
audio_attention_mask.sum(-1)
)
inputs = self.prepare_inputs_for_qwen_audio_encoder(
audio_values=audio_values,
audio_attention_mask=audio_attention_mask,
audio_feat_lengths=audio_feat_lengths,
audio_output_lengths=audio_output_lengths,
)
audio_outputs = self.audio_tower(**inputs)
selected_audio_feature = audio_outputs.last_hidden_state
audio_features = self.audio_modal_projector(selected_audio_feature)
n_audio_tokens = (input_ids == self.config.audio_token_index).sum().item()
n_audio_features = audio_output_lengths.sum()
if n_audio_tokens != n_audio_features:
raise ValueError(
f"Audio features and image tokens do not match: tokens: {n_audio_tokens}, features {n_audio_features}"
)
audio_mask = (
(input_ids == self.config.audio_token_index)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
audio_features = audio_features.to(
inputs_embeds.device, inputs_embeds.dtype
)
audio_features = self.prepare_scattered_audio_values(
audio_features, audio_output_lengths
)
inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
labels=labels,
)
        # When `labels` is passed to the language model, its first output element is the loss, so index the
        # logits and loss explicitly instead of assuming `outputs[0]` is the logits.
        logits = outputs.logits if return_dict else outputs[1 if labels is not None else 0]
        loss = outputs.get("loss", None) if return_dict else (outputs[0] if labels is not None else None)
if labels is not None and loss is None:
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(
logits.device
)
shift_logits = logits[..., :-1, :][
shift_attention_mask.to(logits.device) != 0
].contiguous()
shift_labels = labels[..., 1:][
shift_attention_mask.to(labels.device) != 0
].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1).to(shift_logits.device),
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return AeroCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
audio_hidden_states=audio_features if audio_values is not None else None,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
inputs_embeds=None,
attention_mask=None,
cache_position=None,
logits_to_keep=None,
audio_values=None,
audio_attention_mask=None,
**kwargs,
):
        # Overwritten -- in specific circumstances we don't want to forward audio inputs to the model
model_inputs = self.language_model.prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**kwargs,
)
if cache_position[0] == 0:
model_inputs["audio_values"] = audio_values
model_inputs["audio_attention_mask"] = audio_attention_mask
return model_inputs