kcz358 committed
Commit 2ea8a95 · verified · 1 Parent(s): 5c89dba

Patching flash-attn

Files changed (1):
  modeling_aero.py  +75 -1
modeling_aero.py CHANGED
@@ -30,9 +30,16 @@ from transformers.modeling_outputs import BaseModelOutput, ModelOutput
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.auto import AutoModel, AutoModelForCausalLM
 from transformers.utils import logging
+from transformers.models.qwen2_audio.modeling_qwen2_audio import Qwen2AudioFlashAttention2
 
 from .configuration_aero import AeroConfig
 
+
+try:
+    from flash_attn import flash_attn_func
+except ImportError:
+    print("flash_attn is not installed. Please install flash-attn to use flash attention for the audio tower.")
+
 logger = logging.get_logger(__name__)
 
 
@@ -78,6 +85,72 @@ class AeroCausalLMOutputWithPast(ModelOutput):
     audio_hidden_states: Optional[torch.FloatTensor] = None
 
 
+# The original flash-attention forward for the Qwen2Audio encoder in transformers is buggy,
+# so patch that method with this one.
+def qwen2_audio_flash_attn_forward(
+    self,
+    hidden_states: torch.Tensor,
+    key_value_states=None,
+    past_key_value=None,
+    attention_mask=None,
+    layer_head_mask=None,
+    output_attentions: bool = False,
+    cache_position=None,
+):
+    # Qwen2AudioFlashAttention2 does not support output_attentions
+    if output_attentions:
+        raise ValueError("Qwen2AudioFlashAttention2 attention does not support output_attentions")
+
+    bsz, tgt_len, _ = hidden_states.size()
+
+    # Get the query projection
+    query_states = torch.reshape(self.q_proj(hidden_states), (bsz, tgt_len, self.num_heads, self.head_dim))
+    key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+    value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+    # TODO: These transposes are quite inefficient, but Flash Attention requires the layout
+    # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+    # to avoid many of these transpose/reshape/view calls.
+    key_states = key_states.transpose(1, 2)
+    value_states = value_states.transpose(1, 2)
+
+    causal_mask = attention_mask
+    if attention_mask is not None:  # no matter the length, we just slice it
+        causal_mask = attention_mask[:, : key_states.shape[-2]]
+
+    # In PEFT, the layer norms are usually cast to float32 for training stability, so the input
+    # hidden states get silently cast to float32. Cast them back to the correct dtype to make
+    # sure everything works as expected. This might slow down training and inference, so it is
+    # recommended not to cast the LayerNorms to fp32. (LlamaRMSNorm handles it correctly.)
+    input_dtype = query_states.dtype
+    if input_dtype == torch.float32:
+        if torch.is_autocast_enabled():
+            target_dtype = torch.get_autocast_gpu_dtype()
+        # Handle the case where the model is quantized
+        elif hasattr(self.config, "_pre_quantization_dtype"):
+            target_dtype = self.config._pre_quantization_dtype
+        else:
+            target_dtype = self.q_proj.weight.dtype
+
+        query_states = query_states.to(target_dtype)
+        key_states = key_states.to(target_dtype)
+        value_states = value_states.to(target_dtype)
+
+    dropout = self.dropout if self.training else 0.0
+    attn_output = flash_attn_func(
+        query_states, key_states, value_states, dropout, softmax_scale=None, causal=self.is_causal
+    )
+
+    attn_output = attn_output.reshape(bsz, tgt_len, -1)
+    attn_output = self.out_proj(attn_output)
+
+    if not output_attentions:
+        attn_weights = None
+
+    return attn_output, attn_weights, None
+
+
+
 
 class AeroAudioMultiModalProjector(nn.Module):
     def __init__(self, config: AeroConfig):
@@ -136,7 +209,8 @@ class AeroPreTrainedModel(PreTrainedModel):
 class AeroForConditionalGeneration(AeroPreTrainedModel, GenerationMixin):
     def __init__(self, config: AeroConfig):
         super().__init__(config)
-
+        if config._attn_implementation == "flash_attention_2":
+            Qwen2AudioFlashAttention2.forward = qwen2_audio_flash_attn_forward
         self.audio_tower_type = config.audio_config.model_type
         self.audio_tower = AutoModel.from_config(config.audio_config)
         self.audio_modal_projector = AeroAudioMultiModalProjector(config)
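
A note on the guarded import in the first hunk: it only prints a warning when flash-attn is missing, so flash_attn_func stays undefined and the patched forward would raise a NameError if that code path were ever reached without the package. The lines below are a small alternative sketch, not part of this commit, that records availability explicitly; it assumes a transformers version that exposes is_flash_attn_2_available.

from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available():
    from flash_attn import flash_attn_func
else:
    # Fall back explicitly; callers should check for None before dispatching to flash attention.
    flash_attn_func = None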
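
Inside the patched forward, _shape returns keys and values laid out as [batch, num_heads, seq_len, head_dim], so they are transposed back because flash_attn_func expects [batch, seq_len, num_heads, head_dim]. The standalone sketch below illustrates that layout and the CUDA/half-precision requirement of the flash-attn kernels; the shapes are arbitrary examples, and causal=False simply mirrors an encoder-style call (the patch passes causal=self.is_causal there).

import torch
from flash_attn import flash_attn_func

# flash-attn kernels require a CUDA device and fp16/bf16 inputs.
bsz, seq_len, num_heads, head_dim = 2, 128, 8, 64
q = torch.randn(bsz, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# dropout_p=0.0 matches eval mode in the patch (dropout is only applied during training).
out = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=False)
print(out.shape)  # torch.Size([2, 128, 8, 64])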
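
The last hunk applies the fix by class-level monkey patching: assigning a plain function to Qwen2AudioFlashAttention2.forward rebinds the method for every existing and future instance, because Python resolves forward on the class at call time. The toy classes below sketch the same pattern in isolation; they are stand-ins, not the real transformers modules.

import torch
from torch import nn


class BuggyAttention(nn.Module):
    # Stand-in for Qwen2AudioFlashAttention2; the real class lives in transformers.
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        raise RuntimeError("pretend this implementation is buggy")


def patched_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    # Replacement forward; self is bound automatically because the function is looked up
    # on the class, exactly like qwen2_audio_flash_attn_forward in the diff above.
    return hidden_states * 2


layer = BuggyAttention()                  # instance created before the patch
BuggyAttention.forward = patched_forward  # same move as in AeroForConditionalGeneration.__init__

out = layer(torch.ones(2, 3))             # now dispatches to patched_forward
assert torch.equal(out, torch.full((2, 3), 2.0))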