Patching flash-attn
modeling_aero.py  CHANGED  (+75 -1)
@@ -30,9 +30,16 @@ from transformers.modeling_outputs import BaseModelOutput, ModelOutput
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.auto import AutoModel, AutoModelForCausalLM
 from transformers.utils import logging
+from transformers.models.qwen2_audio.modeling_qwen2_audio import Qwen2AudioFlashAttention2

 from .configuration_aero import AeroConfig

+
+try:
+    from flash_attn import flash_attn_func
+except ImportError:
+    print("flash_attn is not installed. Please install flash-attn to use flash attention for the audio tower.")
+
 logger = logging.get_logger(__name__)


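Note that the `except ImportError` branch above only prints a warning, so if flash-attn is missing the name `flash_attn_func` is never bound and the patched forward below would fail with a `NameError` on its first call. Callers can pre-check availability before requesting the flash-attention path; a minimal sketch (illustrative only, not part of this commit):

```python
# Illustrative pre-check (not part of the patch): only ask for the patched
# flash-attention path when the flash_attn package is actually importable.
import importlib.util

def flash_attn_available() -> bool:
    return importlib.util.find_spec("flash_attn") is not None

attn_implementation = "flash_attention_2" if flash_attn_available() else "sdpa"
```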
@@ -78,6 +85,72 @@ class AeroCausalLMOutputWithPast(ModelOutput):
     audio_hidden_states: Optional[torch.FloatTensor] = None


+# The flash-attention implementation that transformers ships for the Qwen2Audio encoder is buggy,
+# so patch its forward method with this one.
+def qwen2_audio_flash_attn_forward(
+    self,
+    hidden_states: torch.Tensor,
+    key_value_states=None,
+    past_key_value=None,
+    attention_mask=None,
+    layer_head_mask=None,
+    output_attentions: bool = False,
+    cache_position=None,
+):
+    # Qwen2AudioFlashAttention2 attention does not support output_attentions
+    if output_attentions:
+        raise ValueError("Qwen2AudioFlashAttention2 attention does not support output_attentions")
+
+    bsz, tgt_len, _ = hidden_states.size()
+
+    # get query proj
+    query_states = torch.reshape(self.q_proj(hidden_states), (bsz, tgt_len, self.num_heads, self.head_dim))
+    key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+    value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+    # TODO: These transposes are quite inefficient, but flash attention requires the layout
+    # [batch_size, sequence_length, num_heads, head_dim]. Refactoring the KV cache would avoid
+    # many of these transpose/reshape/view ops.
+    key_states = key_states.transpose(1, 2)
+    value_states = value_states.transpose(1, 2)
+
+    causal_mask = attention_mask  # NOTE: flash_attn_func takes no mask argument, so this is currently unused
+    if attention_mask is not None:  # no matter the length, we just slice it
+        causal_mask = attention_mask[:, : key_states.shape[-2]]
+
+    # In PEFT, the layer norms are usually cast to float32 for training stability, so the input
+    # hidden states get silently cast to float32 as well. Cast them back to the correct dtype to
+    # make sure everything works as expected. This might slow down training & inference, so it is
+    # recommended not to cast the LayerNorms to fp32. (LlamaRMSNorm handles it correctly.)
+
+    input_dtype = query_states.dtype
+    if input_dtype == torch.float32:
+        if torch.is_autocast_enabled():
+            target_dtype = torch.get_autocast_gpu_dtype()
+        # Handle the case where the model is quantized
+        elif hasattr(self.config, "_pre_quantization_dtype"):
+            target_dtype = self.config._pre_quantization_dtype
+        else:
+            target_dtype = self.q_proj.weight.dtype
+
+        query_states = query_states.to(target_dtype)
+        key_states = key_states.to(target_dtype)
+        value_states = value_states.to(target_dtype)
+    dropout = self.dropout if self.training else 0.0
+    attn_output = flash_attn_func(
+        query_states, key_states, value_states, dropout, softmax_scale=None, causal=self.is_causal
+    )
+
+    attn_output = attn_output.reshape(bsz, tgt_len, -1)
+    attn_output = self.out_proj(attn_output)
+
+    if not output_attentions:
+        attn_weights = None
+
+    return attn_output, attn_weights, None
+
+
+
 class AeroAudioMultiModalProjector(nn.Module):
     def __init__(self, config: AeroConfig):
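The detail that matters in the patched forward is the memory layout: `flash_attn_func` expects `q`, `k`, `v` shaped `(batch, seq_len, num_heads, head_dim)` in fp16/bf16 on a CUDA device, while eager and SDPA attention work on `(batch, num_heads, seq_len, head_dim)`. That is why the query projection is reshaped to `(bsz, tgt_len, num_heads, head_dim)` directly and the `self._shape(...)` outputs for the keys and values are transposed back with `.transpose(1, 2)`. A small sanity-check sketch of that contract (illustrative only; it assumes a CUDA GPU with flash-attn installed and is not part of the commit):

```python
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

bsz, seq_len, num_heads, head_dim = 2, 128, 8, 64
q = torch.randn(bsz, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# flash-attn layout: (batch, seq_len, num_heads, head_dim)
out_flash = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)

# Reference: SDPA uses (batch, num_heads, seq_len, head_dim), so transpose in and out.
out_ref = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)

# Both are (batch, seq_len, num_heads, head_dim) and should agree up to fp16 tolerance.
torch.testing.assert_close(out_flash, out_ref, atol=2e-3, rtol=2e-3)
```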
@@ -136,7 +209,8 @@ class AeroPreTrainedModel(PreTrainedModel):
 class AeroForConditionalGeneration(AeroPreTrainedModel, GenerationMixin):
     def __init__(self, config: AeroConfig):
         super().__init__(config)
-
+        if config._attn_implementation == "flash_attention_2":
+            Qwen2AudioFlashAttention2.forward = qwen2_audio_flash_attn_forward
         self.audio_tower_type = config.audio_config.model_type
         self.audio_tower = AutoModel.from_config(config.audio_config)
         self.audio_modal_projector = AeroAudioMultiModalProjector(config)
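With the constructor change above, requesting flash attention at load time is all that is needed: `from_pretrained(..., attn_implementation="flash_attention_2")` sets `config._attn_implementation`, and `AeroForConditionalGeneration.__init__` then swaps in the patched forward before the audio tower is built. A hedged usage sketch (the repo path is a placeholder, the auto class may differ, and flash-attn plus a CUDA device are assumed):

```python
import torch
from transformers import AutoModel
from transformers.models.qwen2_audio.modeling_qwen2_audio import Qwen2AudioFlashAttention2

original_forward = Qwen2AudioFlashAttention2.forward

# "path/to/aero-checkpoint" is a placeholder for a checkpoint that ships this modeling_aero.py.
model = AutoModel.from_pretrained(
    "path/to/aero-checkpoint",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

# AeroForConditionalGeneration.__init__ has monkey-patched the Qwen2Audio encoder attention.
assert Qwen2AudioFlashAttention2.forward is not original_forward
```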